Implement string equality operator using IRRT and optimise LLVM implementation #561

Merged
sb10q merged 9 commits from ramtej/nac3:feature/string-equality into master 2024-12-30 13:02:09 +08:00
5 changed files with 95 additions and 88 deletions

View File

@ -4,7 +4,9 @@
#include "irrt/ndarray.hpp"
#include "irrt/range.hpp"
#include "irrt/slice.hpp"
#include "irrt/string.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/ndarray/iter.hpp"
#include "irrt/ndarray/indexing.hpp"
#include "irrt/string.hpp"

View File

@ -0,0 +1,23 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
// Equality test for two (pointer, length) strings.
// `SizeT` is the integer type used for both lengths and for the result.
// Returns 1 when the lengths match and the first `len1` bytes of both buffers
// compare equal, 0 otherwise.
template<typename SizeT>
SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) {
ramtej marked this conversation as resolved Outdated
Outdated
Review

I think this case of pointer equality is rare enough that it does not warrant its special handling and corresponding IR bloat. LLVM is already sluggish enough processing all that IR.

It's also potentially incorrect since you could call the function with the same pointer and different lengths and it would return true.

So remove this line.

I think this case of pointer equality is rare enough that it does not warrant its special handling and corresponding IR bloat. LLVM is already sluggish enough processing all that IR. It's also potentially incorrect since you could call the function with the same pointer and different lengths and it would return true. So remove this line.
// Strings of different lengths can never be equal; this also guarantees that
// the byte comparison below reads the same number of bytes from both buffers.
if (len1 != len2){
return 0;
}
// Lengths are equal here, so comparing `len1` bytes covers both strings in full.
return (__builtin_memcmp(str1, str2, static_cast<SizeT>(len1)) == 0) ? 1 : 0;
Outdated
Review

Is IRRT really required at this point?

Is IRRT really required at this point?

Yes, __builtin_strncmp isn't exposed in LLVM IR so we will need clang to emit the equivalent bitcode for us.

Yes, `__builtin_strncmp` isn't exposed in LLVM IR so we will need `clang` to emit the equivalent bitcode for us.
Outdated
Review

Isn't clang also just a LLVM user?
How does it do it?

Isn't clang also just a LLVM user? How does it do it?
Outdated
Review

Also, NAC3 is just using regular Pascal-style strings, isn't it?
If so, then the correct function to use is memcmp. strncmp is intrinsically slower as it adds a check for NUL characters which is unnecessary and may even lead to incorrect behavior.

Also, NAC3 is just using regular Pascal-style strings, isn't it? If so, then the correct function to use is ``memcmp``. ``strncmp`` is intrinsically slower as it adds a check for NUL characters which is unnecessary and may even lead to incorrect behavior.
}
ramtej marked this conversation as resolved Outdated
Outdated
Review

Isn't that just memcmp but implemented inefficiently?

Isn't that just memcmp but implemented inefficiently?

LLVM has an intrinsic __builtin_memcmp for this exact purpose: https://clang.llvm.org/docs/LanguageExtensions.html#string-builtins

LLVM has an intrinsic `__builtin_memcmp` for this exact purpose: https://clang.llvm.org/docs/LanguageExtensions.html#string-builtins

In fact, LLVM has __builtin_strncmp. If that's the case, the entire IRRT can just be implemented in terms of that intrinsic function.

In fact, LLVM has `__builtin_strncmp`. If that's the case, the entire IRRT can just be implemented in terms of that intrinsic function.
} // namespace
extern "C" {
// String equality entry point for 32-bit lengths.
// Delegates to the shared template implementation instantiated with `uint32_t`
// and hands its result back unchanged.
uint32_t nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) {
    const uint32_t is_equal = __nac3_str_eq_impl<uint32_t>(str1, len1, str2, len2);
    return is_equal;
}
ramtej marked this conversation as resolved Outdated
Outdated
Review

uint64_t-only length looks highly suspicious to me. I thought we had cleaned up all this size_t business already. @derppening

uint64_t-only length looks highly suspicious to me. I thought we had cleaned up all this size_t business already. @derppening

str is implemented as struct { i8*, usize } where usize is size_t. You will need two C functions for this, one with 32-bit and one with 64-bit lengths.

`str` is implemented as `struct { i8*, usize }` where `usize` is `size_t`. You will need two C functions for this, one with 32-bit and one with 64-bit lengths.
// String equality entry point for 64-bit lengths.
// Delegates to the shared template implementation instantiated with `uint64_t`
// and hands its result back unchanged.
uint64_t nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) {
    const uint64_t is_equal = __nac3_str_eq_impl<uint64_t>(str1, len1, str2, len2);
    return is_equal;
}
}

View File

@ -24,7 +24,7 @@ use super::{
irrt::*,
llvm_intrinsics::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_int_umin, call_memcpy_generic,
call_memcpy_generic,
},
macros::codegen_unreachable,
need_sret, numpy,
@ -2045,111 +2045,43 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
} else if left_ty == ctx.primitives.str {
assert!(ctx.unifier.unioned(left_ty, right_ty));
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let lhs = lhs.into_struct_value();
let rhs = rhs.into_struct_value();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
ctx.builder.build_store(plhs, lhs).unwrap();
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
ctx.builder.build_store(prhs, rhs).unwrap();
let lhs_ptr = ctx.build_in_bounds_gep_and_load(
plhs,
&[llvm_usize.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
let lhs_len = ctx.build_in_bounds_gep_and_load(
plhs,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
None,
).into_int_value();
let rhs_ptr = ctx.build_in_bounds_gep_and_load(
prhs,
&[llvm_usize.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
let rhs_len = ctx.build_in_bounds_gep_and_load(
prhs,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
None,
).into_int_value();
let len = call_int_umin(ctx, lhs_len, rhs_len, None);
let current_bb = ctx.builder.get_insert_block().unwrap();
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
ctx.builder.position_at_end(post_foreach_cmp);
let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap();
ctx.builder.position_at_end(current_bb);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let lhs_char = {
let plhs_data = ctx.build_in_bounds_gep_and_load(
plhs,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
ctx.build_in_bounds_gep_and_load(
plhs_data,
&[i],
None
).into_int_value()
};
let rhs_char = {
let prhs_data = ctx.build_in_bounds_gep_and_load(
prhs,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_pointer_value();
ctx.build_in_bounds_gep_and_load(
prhs_data,
&[i],
None
).into_int_value()
};
gen_if_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap())
},
|_, ctx| {
let bb = ctx.builder.get_insert_block().unwrap();
cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]);
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
Ok(())
},
|_, _| Ok(()),
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let bb = ctx.builder.get_insert_block().unwrap();
let is_len_eq = ctx.builder.build_int_compare(
IntPredicate::EQ,
lhs_len,
rhs_len,
"",
).unwrap();
cmp_phi.add_incoming(&[(&is_len_eq, bb)]);
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
ctx.builder.position_at_end(post_foreach_cmp);
let cmp_phi = cmp_phi.as_basic_value().into_int_value();
// Invert the final value if __ne__
let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
if *op == Cmpop::NotEq {
ctx.builder.build_not(cmp_phi, "").unwrap()
ctx.builder.build_not(result, "").unwrap()
} else {
cmp_phi
result
}
} else if [left_ty, right_ty]
.iter()
Outdated
Review

i32?

i32?

Yes, according to https://releases.llvm.org/14.0.0/docs/LangRef.html#getelementptr-instruction:

"The type of each index argument depends on the type it is indexing into. When indexing into a (optionally packed) structure, only i32 integer constants are allowed (when using a vector of indices they must all be the same i32 integer constant). When indexing into an array, pointer or vector, integers of any width are allowed, and they are not required to be constant. These integers are treated as signed values where relevant."

Yes, according to https://releases.llvm.org/14.0.0/docs/LangRef.html#getelementptr-instruction: "The type of each index argument depends on the type it is indexing into. When indexing into a (optionally packed) structure, only i32 integer constants are allowed (when using a vector of indices they must all be the same i32 integer constant). When indexing into an array, pointer or vector, integers of any width are allowed, and they are not required to be constant. These integers are treated as signed values where relevant."

So, since we are indexing into the string as a structure, I have to use i32. I tried using llvm_usize, but I kept running into a segmentation fault because of this GEP format.

So, since we are indexing into the string as a structure, I have to use i32. I tried using llvm_usize, but I kept running into a segmentation fault because of this GEP format.

What about llvm_usize for the first gep and i32 for the second?

What about `llvm_usize` for the first `gep` and `i32` for the second?

Yes, that worked too. Would it be better for me to use that implementation, then?

Yes, that worked too. Would it be better for me to use that implementation, then?

Yes, please use that implementation.

Yes, please use that implementation.

View File

@ -15,12 +15,14 @@ pub use list::*;
pub use math::*;
pub use range::*;
pub use slice::*;
pub use string::*;
mod list;
mod math;
pub mod ndarray;
mod range;
mod slice;
mod string;
#[must_use]
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {

View File

@ -0,0 +1,48 @@
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue};
use itertools::Either;
use crate::codegen::{macros::codegen_unreachable, CodeGenContext, CodeGenerator};
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
str1_ptr: PointerValue<'ctx>,
str1_len: IntValue<'ctx>,
str2_ptr: PointerValue<'ctx>,
str2_len: IntValue<'ctx>,
) -> IntValue<'ctx> {
let (func_name, return_type) = match ctx.ctx.i32_type().get_bit_width() {
32 => ("nac3_str_eq", ctx.ctx.i32_type()),
64 => ("nac3_str_eq64", ctx.ctx.i64_type()),
ramtej marked this conversation as resolved Outdated
Outdated
Review

And when is nac3_str_eq64 called?

And when is ``nac3_str_eq64`` called?
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let func = ctx.module.get_function(func_name).unwrap_or_else(|| {
ctx.module.add_function(
func_name,
return_type.fn_type(
&[
str1_ptr.get_type().into(),
str1_len.get_type().into(),
str2_ptr.get_type().into(),
str2_len.get_type().into(),
],
false,
),
None,
)
});
let result = ctx
.builder
.build_call(
func,
&[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()],
"str_eq_call",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, result)
}