Implement string equality operator using IRRT and optimise LLVM implementation #561
@ -4,7 +4,9 @@
|
||||
#include "irrt/ndarray.hpp"
|
||||
#include "irrt/range.hpp"
|
||||
#include "irrt/slice.hpp"
|
||||
#include "irrt/string.hpp"
|
||||
#include "irrt/ndarray/basic.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
#include "irrt/ndarray/iter.hpp"
|
||||
#include "irrt/ndarray/indexing.hpp"
|
||||
#include "irrt/string.hpp"
|
||||
|
23
nac3core/irrt/irrt/string.hpp
Normal file
@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
namespace {
|
||||
template<typename SizeT>
|
||||
SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) {
|
||||
ramtej marked this conversation as resolved
Outdated
|
||||
if (len1 != len2){
|
||||
return 0;
|
||||
}
|
||||
return (__builtin_memcmp(str1, str2, static_cast<SizeT>(len1)) == 0) ? 1 : 0;
|
||||
sb10q
commented
Is IRRT really required at this point? Is IRRT really required at this point?
derppening
commented
Yes, Yes, `__builtin_strncmp` isn't exposed in LLVM IR so we will need `clang` to emit the equivalent bitcode for us.
sb10q
commented
Isn't clang also just a LLVM user? Isn't clang also just a LLVM user?
How does it do it?
sb10q
commented
Also, NAC3 is just using regular Pascal-style strings, isn't it? Also, NAC3 is just using regular Pascal-style strings, isn't it?
If so, then the correct function to use is ``memcmp``. ``strncmp`` is intrinsically slower as it adds a check for NUL characters which is unnecessary and may even lead to incorrect behavior.
|
||||
}
|
||||
ramtej marked this conversation as resolved
Outdated
sb10q
commented
Isn't that just memcmp but implemented inefficiently? Isn't that just memcmp but implemented inefficiently?
derppening
commented
LLVM has an intrinsic LLVM has an intrinsic `__builtin_memcmp` for this exact purpose: https://clang.llvm.org/docs/LanguageExtensions.html#string-builtins
derppening
commented
In fact, LLVM has In fact, LLVM has `__builtin_strncmp`. If that's the case, the entire IRRT can just be implemented in terms of that intrinsic function.
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
uint32_t nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) {
|
||||
return __nac3_str_eq_impl<uint32_t>(str1, len1, str2, len2);
|
||||
}
|
||||
|
||||
ramtej marked this conversation as resolved
Outdated
sb10q
commented
uint64_t-only length looks highly suspicious to me. I thought we had cleaned up all this size_t business already. @derppening uint64_t-only length looks highly suspicious to me. I thought we had cleaned up all this size_t business already. @derppening
derppening
commented
`str` is implemented as `struct { i8*, usize }` where `usize` is `size_t`. You will need two C functions for this, one with 32-bit and one with 64-bit lengths.
|
||||
uint64_t nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) {
|
||||
return __nac3_str_eq_impl<uint64_t>(str1, len1, str2, len2);
|
||||
}
|
||||
}
|
@ -24,7 +24,7 @@ use super::{
|
||||
irrt::*,
|
||||
llvm_intrinsics::{
|
||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||
call_int_umin, call_memcpy_generic,
|
||||
call_memcpy_generic,
|
||||
},
|
||||
macros::codegen_unreachable,
|
||||
need_sret, numpy,
|
||||
@ -2045,111 +2045,43 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
} else if left_ty == ctx.primitives.str {
|
||||
assert!(ctx.unifier.unioned(left_ty, right_ty));
|
||||
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let lhs = lhs.into_struct_value();
|
||||
let rhs = rhs.into_struct_value();
|
||||
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||
ctx.builder.build_store(plhs, lhs).unwrap();
|
||||
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||
ctx.builder.build_store(prhs, rhs).unwrap();
|
||||
|
||||
let lhs_ptr = ctx.build_in_bounds_gep_and_load(
|
||||
plhs,
|
||||
&[llvm_usize.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_pointer_value();
|
||||
let lhs_len = ctx.build_in_bounds_gep_and_load(
|
||||
plhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
let rhs_ptr = ctx.build_in_bounds_gep_and_load(
|
||||
prhs,
|
||||
&[llvm_usize.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_pointer_value();
|
||||
let rhs_len = ctx.build_in_bounds_gep_and_load(
|
||||
prhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
let len = call_int_umin(ctx, lhs_len, rhs_len, None);
|
||||
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
|
||||
|
||||
ctx.builder.position_at_end(post_foreach_cmp);
|
||||
let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap();
|
||||
ctx.builder.position_at_end(current_bb);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(len, false),
|
||||
|generator, ctx, _, i| {
|
||||
let lhs_char = {
|
||||
let plhs_data = ctx.build_in_bounds_gep_and_load(
|
||||
plhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_pointer_value();
|
||||
|
||||
ctx.build_in_bounds_gep_and_load(
|
||||
plhs_data,
|
||||
&[i],
|
||||
None
|
||||
).into_int_value()
|
||||
};
|
||||
let rhs_char = {
|
||||
let prhs_data = ctx.build_in_bounds_gep_and_load(
|
||||
prhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_pointer_value();
|
||||
|
||||
ctx.build_in_bounds_gep_and_load(
|
||||
prhs_data,
|
||||
&[i],
|
||||
None
|
||||
).into_int_value()
|
||||
};
|
||||
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap())
|
||||
},
|
||||
|_, ctx| {
|
||||
let bb = ctx.builder.get_insert_block().unwrap();
|
||||
cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]);
|
||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, _| Ok(()),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let bb = ctx.builder.get_insert_block().unwrap();
|
||||
let is_len_eq = ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
lhs_len,
|
||||
rhs_len,
|
||||
"",
|
||||
).unwrap();
|
||||
cmp_phi.add_incoming(&[(&is_len_eq, bb)]);
|
||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(post_foreach_cmp);
|
||||
let cmp_phi = cmp_phi.as_basic_value().into_int_value();
|
||||
|
||||
// Invert the final value if __ne__
|
||||
let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
|
||||
if *op == Cmpop::NotEq {
|
||||
ctx.builder.build_not(cmp_phi, "").unwrap()
|
||||
ctx.builder.build_not(result, "").unwrap()
|
||||
} else {
|
||||
cmp_phi
|
||||
result
|
||||
}
|
||||
} else if [left_ty, right_ty]
|
||||
.iter()
|
||||
sb10q
commented
i32? i32?
ramtej
commented
Yes, according to https://releases.llvm.org/14.0.0/docs/LangRef.html#getelementptr-instruction: "The type of each index argument depends on the type it is indexing into. When indexing into a (optionally packed) structure, only i32 integer constants are allowed (when using a vector of indices they must all be the same i32 integer constant). When indexing into an array, pointer or vector, integers of any width are allowed, and they are not required to be constant. These integers are treated as signed values where relevant." Yes, according to https://releases.llvm.org/14.0.0/docs/LangRef.html#getelementptr-instruction:
"The type of each index argument depends on the type it is indexing into. When indexing into a (optionally packed) structure, only i32 integer constants are allowed (when using a vector of indices they must all be the same i32 integer constant). When indexing into an array, pointer or vector, integers of any width are allowed, and they are not required to be constant. These integers are treated as signed values where relevant."
ramtej
commented
so since we are using string as a structure, i have to use i32. i tried using llvm_usize, but i kept running into segmentation error because of this GEP format so since we are using string as a structure, i have to use i32. i tried using llvm_usize, but i kept running into segmentation error because of this GEP format
derppening
commented
What about What about `llvm_usize` for the first `gep` and `i32` for the second?
ramtej
commented
yea, that worked too, would it be better for me to use that implementation then? yea, that worked too, would it be better for me to use that implementation then?
derppening
commented
Yes, please use that implementation. Yes, please use that implementation.
|
||||
|
@ -15,12 +15,14 @@ pub use list::*;
|
||||
pub use math::*;
|
||||
pub use range::*;
|
||||
pub use slice::*;
|
||||
pub use string::*;
|
||||
|
||||
mod list;
|
||||
mod math;
|
||||
pub mod ndarray;
|
||||
mod range;
|
||||
mod slice;
|
||||
mod string;
|
||||
|
||||
#[must_use]
|
||||
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
||||
|
48
nac3core/src/codegen/irrt/string.rs
Normal file
@ -0,0 +1,48 @@
|
||||
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue};
|
||||
use itertools::Either;
|
||||
|
||||
use crate::codegen::{macros::codegen_unreachable, CodeGenContext, CodeGenerator};
|
||||
|
||||
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
|
||||
pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
str1_ptr: PointerValue<'ctx>,
|
||||
str1_len: IntValue<'ctx>,
|
||||
str2_ptr: PointerValue<'ctx>,
|
||||
str2_len: IntValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let (func_name, return_type) = match ctx.ctx.i32_type().get_bit_width() {
|
||||
32 => ("nac3_str_eq", ctx.ctx.i32_type()),
|
||||
64 => ("nac3_str_eq64", ctx.ctx.i64_type()),
|
||||
ramtej marked this conversation as resolved
Outdated
sb10q
commented
And when is And when is ``nac3_str_eq64`` called?
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
|
||||
let func = ctx.module.get_function(func_name).unwrap_or_else(|| {
|
||||
ctx.module.add_function(
|
||||
func_name,
|
||||
return_type.fn_type(
|
||||
&[
|
||||
str1_ptr.get_type().into(),
|
||||
str1_len.get_type().into(),
|
||||
str2_ptr.get_type().into(),
|
||||
str2_len.get_type().into(),
|
||||
],
|
||||
false,
|
||||
),
|
||||
None,
|
||||
)
|
||||
});
|
||||
let result = ctx
|
||||
.builder
|
||||
.build_call(
|
||||
func,
|
||||
&[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()],
|
||||
"str_eq_call",
|
||||
)
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
generator.bool_to_i1(ctx, result)
|
||||
}
|
I think this case of pointer equality is rare enough that it does not warrant its special handling and corresponding IR bloat. LLVM is already sluggish enough processing all that IR.
It's also potentially incorrect since you could call the function with the same pointer and different lengths and it would return true.
So remove this line.