Implement string equality operator using IRRT and optimise LLVM implementation #561

Merged
sb10q merged 9 commits from ramtej/nac3:feature/string-equality into master 2024-12-30 13:02:09 +08:00
6 changed files with 94 additions and 87 deletions
Showing only changes of commit e13d753329 - Show all commits

View File

@ -3,3 +3,4 @@
#include "irrt/math.hpp"
#include "irrt/ndarray.hpp"
#include "irrt/slice.hpp"
#include "irrt/string.hpp"

View File

@ -0,0 +1,22 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
template<typename SizeT>
int32_t __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) {
if (str1 == str2) return 1;
ramtej marked this conversation as resolved Outdated
Outdated
Review

I think this case of pointer equality is rare enough that it does not warrant its special handling and corresponding IR bloat. LLVM is already sluggish enough processing all that IR.

It's also potentially incorrect since you could call the function with the same pointer and different lengths and it would return true.

So remove this line.

I think this case of pointer equality is rare enough that it does not warrant its special handling and corresponding IR bloat. LLVM is already sluggish enough processing all that IR. It's also potentially incorrect since you could call the function with the same pointer and different lengths and it would return true. So remove this line.
if (len1 != len2) return 0;
for (SizeT i = 0; i < len1; ++i) {
if (static_cast<unsigned char>(str1[i]) != static_cast<unsigned char>(str2[i])) {
return 0;
Outdated
Review

Is IRRT really required at this point?

Is IRRT really required at this point?

Yes, __builtin_strncmp isn't exposed in LLVM IR so we will need clang to emit the equivalent bitcode for us.

Yes, `__builtin_strncmp` isn't exposed in LLVM IR so we will need `clang` to emit the equivalent bitcode for us.
Outdated
Review

Isn't clang also just a LLVM user?
How does it do it?

Isn't clang also just a LLVM user? How does it do it?
Outdated
Review

Also, NAC3 is just using regular Pascal-style strings, isn't it?
If so, then the correct function to use is memcmp. strncmp is intrinsically slower as it adds a check for NUL characters which is unnecessary and may even lead to incorrect behavior.

Also, NAC3 is just using regular Pascal-style strings, isn't it? If so, then the correct function to use is ``memcmp``. ``strncmp`` is intrinsically slower as it adds a check for NUL characters which is unnecessary and may even lead to incorrect behavior.
}
ramtej marked this conversation as resolved Outdated
Outdated
Review

Isn't that just memcmp but implemented inefficiently?

Isn't that just memcmp but implemented inefficiently?

LLVM has an intrinsic __builtin_memcmp for this exact purpose: https://clang.llvm.org/docs/LanguageExtensions.html#string-builtins

LLVM has an intrinsic `__builtin_memcmp` for this exact purpose: https://clang.llvm.org/docs/LanguageExtensions.html#string-builtins

In fact, LLVM has __builtin_strncmp. If that's the case, the entire IRRT can just be implemented in terms of that intrinsic function.

In fact, LLVM has `__builtin_strncmp`. If that's the case, the entire IRRT can just be implemented in terms of that intrinsic function.
}
return 1;
}
} // namespace
extern "C" {
int32_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) {
ramtej marked this conversation as resolved Outdated
Outdated
Review

uint64_t-only length looks highly suspicious to me. I thought we had cleaned up all this size_t business already. @derppening

uint64_t-only length looks highly suspicious to me. I thought we had cleaned up all this size_t business already. @derppening

str is implemented as struct { i8*, usize } where usize is size_t. You will need two C functions for this, one with 32-bit and one with 64-bit lengths.

`str` is implemented as `struct { i8*, usize }` where `usize` is `size_t`. You will need two C functions for this, one with 32-bit and one with 64-bit lengths.
return __nac3_str_eq_impl<uint64_t>(str1, len1, str2, len2);
}
}

View File

@ -24,7 +24,7 @@ use super::{
irrt::*,
llvm_intrinsics::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_int_umin, call_memcpy_generic,
call_memcpy_generic,
},
macros::codegen_unreachable,
need_sret, numpy,
@ -2072,111 +2072,42 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
} else if left_ty == ctx.primitives.str {
assert!(ctx.unifier.unioned(left_ty, right_ty));
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let lhs = lhs.into_struct_value();
let rhs = rhs.into_struct_value();
let llvm_i32 = ctx.ctx.i32_type();
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
ctx.builder.build_store(plhs, lhs).unwrap();
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
ctx.builder.build_store(prhs, rhs).unwrap();
let lhs_ptr = ctx.build_in_bounds_gep_and_load(
plhs,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
Outdated
Review

i32?

i32?

Yes, according to https://releases.llvm.org/14.0.0/docs/LangRef.html#getelementptr-instruction:

"The type of each index argument depends on the type it is indexing into. When indexing into a (optionally packed) structure, only i32 integer constants are allowed (when using a vector of indices they must all be the same i32 integer constant). When indexing into an array, pointer or vector, integers of any width are allowed, and they are not required to be constant. These integers are treated as signed values where relevant."

Yes, according to https://releases.llvm.org/14.0.0/docs/LangRef.html#getelementptr-instruction: "The type of each index argument depends on the type it is indexing into. When indexing into a (optionally packed) structure, only i32 integer constants are allowed (when using a vector of indices they must all be the same i32 integer constant). When indexing into an array, pointer or vector, integers of any width are allowed, and they are not required to be constant. These integers are treated as signed values where relevant."

so since we are using string as a structure, i have to use i32. i tried using llvm_usize, but i kept running into segmentation error because of this GEP format

so since we are using string as a structure, i have to use i32. i tried using llvm_usize, but i kept running into segmentation error because of this GEP format

What about llvm_usize for the first gep and i32 for the second?

What about `llvm_usize` for the first `gep` and `i32` for the second?

yea, that worked too, would it be better for me to use that implementation then?

yea, that worked too, would it be better for me to use that implementation then?

Yes, please use that implementation.

Yes, please use that implementation.
None,
).into_pointer_value();
let lhs_len = ctx.build_in_bounds_gep_and_load(
plhs,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
None,
).into_int_value();
let rhs_ptr = ctx.build_in_bounds_gep_and_load(
prhs,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
let rhs_len = ctx.build_in_bounds_gep_and_load(
prhs,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
None,
).into_int_value();
let len = call_int_umin(ctx, lhs_len, rhs_len, None);
let current_bb = ctx.builder.get_insert_block().unwrap();
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
ctx.builder.position_at_end(post_foreach_cmp);
let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap();
ctx.builder.position_at_end(current_bb);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let lhs_char = {
let plhs_data = ctx.build_in_bounds_gep_and_load(
plhs,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
ctx.build_in_bounds_gep_and_load(
plhs_data,
&[i],
None
).into_int_value()
};
let rhs_char = {
let prhs_data = ctx.build_in_bounds_gep_and_load(
prhs,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
ctx.build_in_bounds_gep_and_load(
prhs_data,
&[i],
None
).into_int_value()
};
gen_if_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap())
},
|_, ctx| {
let bb = ctx.builder.get_insert_block().unwrap();
cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]);
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
Ok(())
},
|_, _| Ok(()),
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let bb = ctx.builder.get_insert_block().unwrap();
let is_len_eq = ctx.builder.build_int_compare(
IntPredicate::EQ,
lhs_len,
rhs_len,
"",
).unwrap();
cmp_phi.add_incoming(&[(&is_len_eq, bb)]);
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
ctx.builder.position_at_end(post_foreach_cmp);
let cmp_phi = cmp_phi.as_basic_value().into_int_value();
// Invert the final value if __ne__
let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
if *op == Cmpop::NotEq {
ctx.builder.build_not(cmp_phi, "").unwrap()
ctx.builder.build_not(result, "").unwrap()
} else {
cmp_phi
result
}
} else if [left_ty, right_ty]
.iter()

View File

@ -15,11 +15,13 @@ pub use list::*;
pub use math::*;
pub use ndarray::*;
pub use slice::*;
pub use string::*;
mod list;
mod math;
mod ndarray;
mod slice;
mod string;
#[must_use]
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {

View File

@ -0,0 +1,45 @@
use inkwell::{
values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue},
AddressSpace,
};
use itertools::Either;
use crate::codegen::{CodeGenContext, CodeGenerator};
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
str1_ptr: PointerValue<'ctx>,
str1_len: IntValue<'ctx>,
str2_ptr: PointerValue<'ctx>,
str2_len: IntValue<'ctx>,
) -> IntValue<'ctx> {
ramtej marked this conversation as resolved Outdated
Outdated
Review

And when is nac3_str_eq64 called?

And when is ``nac3_str_eq64`` called?
let string_eq_fn = ctx.module.get_function("nac3_str_eq").unwrap_or_else(|| {
let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type();
let fn_type = i32_type.fn_type(
&[i8_ptr_type.into(), i64_type.into(), i8_ptr_type.into(), i64_type.into()],
false,
);
ctx.module.add_function("nac3_str_eq", fn_type, None)
});
let result = ctx
.builder
.build_call(
string_eq_fn,
&[
str1_ptr.into(),
ctx.builder.build_int_z_extend(str1_len, ctx.ctx.i64_type(), "").unwrap().into(),
str2_ptr.into(),
ctx.builder.build_int_z_extend(str2_len, ctx.ctx.i64_type(), "").unwrap().into(),
],
"string_eq",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, result)
}

View File

@ -11,6 +11,9 @@ def str_eq():
output_bool("a" == "a")
output_bool("test string" == "test string")
output_bool("test string1" == "test string2")
output_bool("test" == "testing")
output_bool("abcd" == "abdc")
output_bool(" " == " ")
def str_ne():
@ -21,10 +24,13 @@ def str_ne():
output_bool("a" != "a")
output_bool("test string" != "test string")
output_bool("test string1" != "test string2")
output_bool("test" != "testing")
output_bool("abcd" != "abdc")
output_bool(" " != " ")
def run() -> int32:
str_eq()
str_ne()
return 0
return 0
Outdated
Review

Again I don't see the point of these test changes.

Again I don't see the point of these test changes.