Implement string equality operator using IRRT and optimise LLVM implementation #561
@ -3,3 +3,4 @@
|
|||||||
#include "irrt/math.hpp"
|
#include "irrt/math.hpp"
|
||||||
#include "irrt/ndarray.hpp"
|
#include "irrt/ndarray.hpp"
|
||||||
#include "irrt/slice.hpp"
|
#include "irrt/slice.hpp"
|
||||||
|
#include "irrt/string.hpp"
|
22
nac3core/irrt/irrt/string.hpp
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
namespace {
|
||||||
|
template<typename SizeT>
|
||||||
|
int32_t __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) {
|
||||||
|
if (str1 == str2) return 1;
|
||||||
ramtej marked this conversation as resolved
Outdated
|
|||||||
|
if (len1 != len2) return 0;
|
||||||
|
for (SizeT i = 0; i < len1; ++i) {
|
||||||
|
if (static_cast<unsigned char>(str1[i]) != static_cast<unsigned char>(str2[i])) {
|
||||||
|
return 0;
|
||||||
sb10q
commented
Is IRRT really required at this point?
derppening
commented
Yes, `__builtin_strncmp` isn't exposed in LLVM IR so we will need `clang` to emit the equivalent bitcode for us.
sb10q
commented
Isn't clang also just a LLVM user?
How does it do it?
sb10q
commented
Also, NAC3 is just using regular Pascal-style strings, isn't it?
If so, then the correct function to use is ``memcmp``. ``strncmp`` is intrinsically slower as it adds a check for NUL characters which is unnecessary and may even lead to incorrect behavior.
|
|||||||
|
}
|
||||||
ramtej marked this conversation as resolved
Outdated
sb10q
commented
Isn't that just memcmp but implemented inefficiently?
derppening
commented
LLVM has an intrinsic `__builtin_memcmp` for this exact purpose: https://clang.llvm.org/docs/LanguageExtensions.html#string-builtins
derppening
commented
In fact, LLVM has `__builtin_strncmp`. If that's the case, the entire IRRT can just be implemented in terms of that intrinsic function.
|
|||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
int32_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) {
|
||||||
ramtej marked this conversation as resolved
Outdated
sb10q
commented
uint64_t-only length looks highly suspicious to me. I thought we had cleaned up all this size_t business already. @derppening
derppening
commented
`str` is implemented as `struct { i8*, usize }` where `usize` is `size_t`. You will need two C functions for this, one with 32-bit and one with 64-bit lengths.
|
|||||||
|
return __nac3_str_eq_impl<uint64_t>(str1, len1, str2, len2);
|
||||||
|
}
|
||||||
|
}
|
@ -24,7 +24,7 @@ use super::{
|
|||||||
irrt::*,
|
irrt::*,
|
||||||
llvm_intrinsics::{
|
llvm_intrinsics::{
|
||||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||||
call_int_umin, call_memcpy_generic,
|
call_memcpy_generic,
|
||||||
},
|
},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
need_sret, numpy,
|
need_sret, numpy,
|
||||||
@ -2072,111 +2072,42 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
} else if left_ty == ctx.primitives.str {
|
} else if left_ty == ctx.primitives.str {
|
||||||
assert!(ctx.unifier.unioned(left_ty, right_ty));
|
assert!(ctx.unifier.unioned(left_ty, right_ty));
|
||||||
|
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let lhs = lhs.into_struct_value();
|
let lhs = lhs.into_struct_value();
|
||||||
let rhs = rhs.into_struct_value();
|
let rhs = rhs.into_struct_value();
|
||||||
|
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||||
ctx.builder.build_store(plhs, lhs).unwrap();
|
ctx.builder.build_store(plhs, lhs).unwrap();
|
||||||
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||||
ctx.builder.build_store(prhs, rhs).unwrap();
|
ctx.builder.build_store(prhs, rhs).unwrap();
|
||||||
|
|
||||||
|
let lhs_ptr = ctx.build_in_bounds_gep_and_load(
|
||||||
|
plhs,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
sb10q
commented
i32?
ramtej
commented
Yes, according to https://releases.llvm.org/14.0.0/docs/LangRef.html#getelementptr-instruction:
"The type of each index argument depends on the type it is indexing into. When indexing into a (optionally packed) structure, only i32 integer constants are allowed (when using a vector of indices they must all be the same i32 integer constant). When indexing into an array, pointer or vector, integers of any width are allowed, and they are not required to be constant. These integers are treated as signed values where relevant."
ramtej
commented
So since we are using string as a structure, I have to use i32. I tried using llvm_usize, but I kept running into segmentation errors because of this GEP format.
derppening
commented
What about `llvm_usize` for the first `gep` and `i32` for the second?
ramtej
commented
Yeah, that worked too. Would it be better for me to use that implementation then?
derppening
commented
Yes, please use that implementation.
|
|||||||
|
None,
|
||||||
|
).into_pointer_value();
|
||||||
let lhs_len = ctx.build_in_bounds_gep_and_load(
|
let lhs_len = ctx.build_in_bounds_gep_and_load(
|
||||||
plhs,
|
plhs,
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||||
None,
|
None,
|
||||||
).into_int_value();
|
).into_int_value();
|
||||||
|
|
||||||
|
let rhs_ptr = ctx.build_in_bounds_gep_and_load(
|
||||||
|
prhs,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
None,
|
||||||
|
).into_pointer_value();
|
||||||
let rhs_len = ctx.build_in_bounds_gep_and_load(
|
let rhs_len = ctx.build_in_bounds_gep_and_load(
|
||||||
prhs,
|
prhs,
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||||
None,
|
None,
|
||||||
).into_int_value();
|
).into_int_value();
|
||||||
|
let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
|
||||||
let len = call_int_umin(ctx, lhs_len, rhs_len, None);
|
|
||||||
|
|
||||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
|
||||||
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
|
|
||||||
|
|
||||||
ctx.builder.position_at_end(post_foreach_cmp);
|
|
||||||
let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap();
|
|
||||||
ctx.builder.position_at_end(current_bb);
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
(len, false),
|
|
||||||
|generator, ctx, _, i| {
|
|
||||||
let lhs_char = {
|
|
||||||
let plhs_data = ctx.build_in_bounds_gep_and_load(
|
|
||||||
plhs,
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
|
||||||
None,
|
|
||||||
).into_pointer_value();
|
|
||||||
|
|
||||||
ctx.build_in_bounds_gep_and_load(
|
|
||||||
plhs_data,
|
|
||||||
&[i],
|
|
||||||
None
|
|
||||||
).into_int_value()
|
|
||||||
};
|
|
||||||
let rhs_char = {
|
|
||||||
let prhs_data = ctx.build_in_bounds_gep_and_load(
|
|
||||||
prhs,
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
|
||||||
None,
|
|
||||||
).into_pointer_value();
|
|
||||||
|
|
||||||
ctx.build_in_bounds_gep_and_load(
|
|
||||||
prhs_data,
|
|
||||||
&[i],
|
|
||||||
None
|
|
||||||
).into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
gen_if_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap())
|
|
||||||
},
|
|
||||||
|_, ctx| {
|
|
||||||
let bb = ctx.builder.get_insert_block().unwrap();
|
|
||||||
cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]);
|
|
||||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
|_, _| Ok(()),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let bb = ctx.builder.get_insert_block().unwrap();
|
|
||||||
let is_len_eq = ctx.builder.build_int_compare(
|
|
||||||
IntPredicate::EQ,
|
|
||||||
lhs_len,
|
|
||||||
rhs_len,
|
|
||||||
"",
|
|
||||||
).unwrap();
|
|
||||||
cmp_phi.add_incoming(&[(&is_len_eq, bb)]);
|
|
||||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
|
||||||
|
|
||||||
ctx.builder.position_at_end(post_foreach_cmp);
|
|
||||||
let cmp_phi = cmp_phi.as_basic_value().into_int_value();
|
|
||||||
|
|
||||||
// Invert the final value if __ne__
|
|
||||||
if *op == Cmpop::NotEq {
|
if *op == Cmpop::NotEq {
|
||||||
ctx.builder.build_not(cmp_phi, "").unwrap()
|
ctx.builder.build_not(result, "").unwrap()
|
||||||
} else {
|
} else {
|
||||||
cmp_phi
|
result
|
||||||
}
|
}
|
||||||
} else if [left_ty, right_ty]
|
} else if [left_ty, right_ty]
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -15,11 +15,13 @@ pub use list::*;
|
|||||||
pub use math::*;
|
pub use math::*;
|
||||||
pub use ndarray::*;
|
pub use ndarray::*;
|
||||||
pub use slice::*;
|
pub use slice::*;
|
||||||
|
pub use string::*;
|
||||||
|
|
||||||
mod list;
|
mod list;
|
||||||
mod math;
|
mod math;
|
||||||
mod ndarray;
|
mod ndarray;
|
||||||
mod slice;
|
mod slice;
|
||||||
|
mod string;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
||||||
|
45
nac3core/src/codegen/irrt/string.rs
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
use inkwell::{
|
||||||
|
values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
use itertools::Either;
|
||||||
|
|
||||||
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
|
||||||
|
pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
str1_ptr: PointerValue<'ctx>,
|
||||||
|
str1_len: IntValue<'ctx>,
|
||||||
|
str2_ptr: PointerValue<'ctx>,
|
||||||
|
str2_len: IntValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
ramtej marked this conversation as resolved
Outdated
sb10q
commented
And when is ``nac3_str_eq64`` called?
|
|||||||
|
let string_eq_fn = ctx.module.get_function("nac3_str_eq").unwrap_or_else(|| {
|
||||||
|
let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||||
|
let i64_type = ctx.ctx.i64_type();
|
||||||
|
let i32_type = ctx.ctx.i32_type();
|
||||||
|
let fn_type = i32_type.fn_type(
|
||||||
|
&[i8_ptr_type.into(), i64_type.into(), i8_ptr_type.into(), i64_type.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
ctx.module.add_function("nac3_str_eq", fn_type, None)
|
||||||
|
});
|
||||||
|
let result = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(
|
||||||
|
string_eq_fn,
|
||||||
|
&[
|
||||||
|
str1_ptr.into(),
|
||||||
|
ctx.builder.build_int_z_extend(str1_len, ctx.ctx.i64_type(), "").unwrap().into(),
|
||||||
|
str2_ptr.into(),
|
||||||
|
ctx.builder.build_int_z_extend(str2_len, ctx.ctx.i64_type(), "").unwrap().into(),
|
||||||
|
],
|
||||||
|
"string_eq",
|
||||||
|
)
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap();
|
||||||
|
generator.bool_to_i1(ctx, result)
|
||||||
|
}
|
@ -11,6 +11,9 @@ def str_eq():
|
|||||||
output_bool("a" == "a")
|
output_bool("a" == "a")
|
||||||
output_bool("test string" == "test string")
|
output_bool("test string" == "test string")
|
||||||
output_bool("test string1" == "test string2")
|
output_bool("test string1" == "test string2")
|
||||||
|
output_bool("test" == "testing")
|
||||||
|
output_bool("abcd" == "abdc")
|
||||||
|
output_bool(" " == " ")
|
||||||
|
|
||||||
|
|
||||||
def str_ne():
|
def str_ne():
|
||||||
@ -21,6 +24,9 @@ def str_ne():
|
|||||||
output_bool("a" != "a")
|
output_bool("a" != "a")
|
||||||
output_bool("test string" != "test string")
|
output_bool("test string" != "test string")
|
||||||
output_bool("test string1" != "test string2")
|
output_bool("test string1" != "test string2")
|
||||||
|
output_bool("test" != "testing")
|
||||||
|
output_bool("abcd" != "abdc")
|
||||||
|
output_bool(" " != " ")
|
||||||
|
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
|
I think this case of pointer equality is rare enough that it does not warrant its special handling and corresponding IR bloat. LLVM is already sluggish enough processing all that IR.
It's also potentially incorrect since you could call the function with the same pointer and different lengths and it would return true.
So remove this line.