Implement string equality operator using IRRT and optimise LLVM implementation
This commit is contained in:
parent
ad67a99c8f
commit
e13d753329
@ -3,3 +3,4 @@
|
|||||||
#include "irrt/math.hpp"
|
#include "irrt/math.hpp"
|
||||||
#include "irrt/ndarray.hpp"
|
#include "irrt/ndarray.hpp"
|
||||||
#include "irrt/slice.hpp"
|
#include "irrt/slice.hpp"
|
||||||
|
#include "irrt/string.hpp"
|
22
nac3core/irrt/irrt/string.hpp
Normal file
22
nac3core/irrt/irrt/string.hpp
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
namespace {
|
||||||
|
template<typename SizeT>
|
||||||
|
int32_t __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) {
|
||||||
|
if (str1 == str2) return 1;
|
||||||
|
if (len1 != len2) return 0;
|
||||||
|
for (SizeT i = 0; i < len1; ++i) {
|
||||||
|
if (static_cast<unsigned char>(str1[i]) != static_cast<unsigned char>(str2[i])) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
int32_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) {
|
||||||
|
return __nac3_str_eq_impl<uint64_t>(str1, len1, str2, len2);
|
||||||
|
}
|
||||||
|
}
|
@ -24,7 +24,7 @@ use super::{
|
|||||||
irrt::*,
|
irrt::*,
|
||||||
llvm_intrinsics::{
|
llvm_intrinsics::{
|
||||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||||
call_int_umin, call_memcpy_generic,
|
call_memcpy_generic,
|
||||||
},
|
},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
need_sret, numpy,
|
need_sret, numpy,
|
||||||
@ -2072,111 +2072,42 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
} else if left_ty == ctx.primitives.str {
|
} else if left_ty == ctx.primitives.str {
|
||||||
assert!(ctx.unifier.unioned(left_ty, right_ty));
|
assert!(ctx.unifier.unioned(left_ty, right_ty));
|
||||||
|
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let lhs = lhs.into_struct_value();
|
let lhs = lhs.into_struct_value();
|
||||||
let rhs = rhs.into_struct_value();
|
let rhs = rhs.into_struct_value();
|
||||||
|
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||||
ctx.builder.build_store(plhs, lhs).unwrap();
|
ctx.builder.build_store(plhs, lhs).unwrap();
|
||||||
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||||
ctx.builder.build_store(prhs, rhs).unwrap();
|
ctx.builder.build_store(prhs, rhs).unwrap();
|
||||||
|
|
||||||
|
let lhs_ptr = ctx.build_in_bounds_gep_and_load(
|
||||||
|
plhs,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
None,
|
||||||
|
).into_pointer_value();
|
||||||
let lhs_len = ctx.build_in_bounds_gep_and_load(
|
let lhs_len = ctx.build_in_bounds_gep_and_load(
|
||||||
plhs,
|
plhs,
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||||
None,
|
None,
|
||||||
).into_int_value();
|
).into_int_value();
|
||||||
|
|
||||||
|
let rhs_ptr = ctx.build_in_bounds_gep_and_load(
|
||||||
|
prhs,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
None,
|
||||||
|
).into_pointer_value();
|
||||||
let rhs_len = ctx.build_in_bounds_gep_and_load(
|
let rhs_len = ctx.build_in_bounds_gep_and_load(
|
||||||
prhs,
|
prhs,
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||||
None,
|
None,
|
||||||
).into_int_value();
|
).into_int_value();
|
||||||
|
let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
|
||||||
let len = call_int_umin(ctx, lhs_len, rhs_len, None);
|
|
||||||
|
|
||||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
|
||||||
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
|
|
||||||
|
|
||||||
ctx.builder.position_at_end(post_foreach_cmp);
|
|
||||||
let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap();
|
|
||||||
ctx.builder.position_at_end(current_bb);
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
(len, false),
|
|
||||||
|generator, ctx, _, i| {
|
|
||||||
let lhs_char = {
|
|
||||||
let plhs_data = ctx.build_in_bounds_gep_and_load(
|
|
||||||
plhs,
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
|
||||||
None,
|
|
||||||
).into_pointer_value();
|
|
||||||
|
|
||||||
ctx.build_in_bounds_gep_and_load(
|
|
||||||
plhs_data,
|
|
||||||
&[i],
|
|
||||||
None
|
|
||||||
).into_int_value()
|
|
||||||
};
|
|
||||||
let rhs_char = {
|
|
||||||
let prhs_data = ctx.build_in_bounds_gep_and_load(
|
|
||||||
prhs,
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
|
||||||
None,
|
|
||||||
).into_pointer_value();
|
|
||||||
|
|
||||||
ctx.build_in_bounds_gep_and_load(
|
|
||||||
prhs_data,
|
|
||||||
&[i],
|
|
||||||
None
|
|
||||||
).into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
gen_if_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap())
|
|
||||||
},
|
|
||||||
|_, ctx| {
|
|
||||||
let bb = ctx.builder.get_insert_block().unwrap();
|
|
||||||
cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]);
|
|
||||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
|_, _| Ok(()),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let bb = ctx.builder.get_insert_block().unwrap();
|
|
||||||
let is_len_eq = ctx.builder.build_int_compare(
|
|
||||||
IntPredicate::EQ,
|
|
||||||
lhs_len,
|
|
||||||
rhs_len,
|
|
||||||
"",
|
|
||||||
).unwrap();
|
|
||||||
cmp_phi.add_incoming(&[(&is_len_eq, bb)]);
|
|
||||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
|
||||||
|
|
||||||
ctx.builder.position_at_end(post_foreach_cmp);
|
|
||||||
let cmp_phi = cmp_phi.as_basic_value().into_int_value();
|
|
||||||
|
|
||||||
// Invert the final value if __ne__
|
|
||||||
if *op == Cmpop::NotEq {
|
if *op == Cmpop::NotEq {
|
||||||
ctx.builder.build_not(cmp_phi, "").unwrap()
|
ctx.builder.build_not(result, "").unwrap()
|
||||||
} else {
|
} else {
|
||||||
cmp_phi
|
result
|
||||||
}
|
}
|
||||||
} else if [left_ty, right_ty]
|
} else if [left_ty, right_ty]
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -15,11 +15,13 @@ pub use list::*;
|
|||||||
pub use math::*;
|
pub use math::*;
|
||||||
pub use ndarray::*;
|
pub use ndarray::*;
|
||||||
pub use slice::*;
|
pub use slice::*;
|
||||||
|
pub use string::*;
|
||||||
|
|
||||||
mod list;
|
mod list;
|
||||||
mod math;
|
mod math;
|
||||||
mod ndarray;
|
mod ndarray;
|
||||||
mod slice;
|
mod slice;
|
||||||
|
mod string;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
||||||
|
45
nac3core/src/codegen/irrt/string.rs
Normal file
45
nac3core/src/codegen/irrt/string.rs
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
use inkwell::{
|
||||||
|
values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
use itertools::Either;
|
||||||
|
|
||||||
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
|
||||||
|
pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
str1_ptr: PointerValue<'ctx>,
|
||||||
|
str1_len: IntValue<'ctx>,
|
||||||
|
str2_ptr: PointerValue<'ctx>,
|
||||||
|
str2_len: IntValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let string_eq_fn = ctx.module.get_function("nac3_str_eq").unwrap_or_else(|| {
|
||||||
|
let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||||
|
let i64_type = ctx.ctx.i64_type();
|
||||||
|
let i32_type = ctx.ctx.i32_type();
|
||||||
|
let fn_type = i32_type.fn_type(
|
||||||
|
&[i8_ptr_type.into(), i64_type.into(), i8_ptr_type.into(), i64_type.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
ctx.module.add_function("nac3_str_eq", fn_type, None)
|
||||||
|
});
|
||||||
|
let result = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(
|
||||||
|
string_eq_fn,
|
||||||
|
&[
|
||||||
|
str1_ptr.into(),
|
||||||
|
ctx.builder.build_int_z_extend(str1_len, ctx.ctx.i64_type(), "").unwrap().into(),
|
||||||
|
str2_ptr.into(),
|
||||||
|
ctx.builder.build_int_z_extend(str2_len, ctx.ctx.i64_type(), "").unwrap().into(),
|
||||||
|
],
|
||||||
|
"string_eq",
|
||||||
|
)
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap();
|
||||||
|
generator.bool_to_i1(ctx, result)
|
||||||
|
}
|
@ -11,6 +11,9 @@ def str_eq():
|
|||||||
output_bool("a" == "a")
|
output_bool("a" == "a")
|
||||||
output_bool("test string" == "test string")
|
output_bool("test string" == "test string")
|
||||||
output_bool("test string1" == "test string2")
|
output_bool("test string1" == "test string2")
|
||||||
|
output_bool("test" == "testing")
|
||||||
|
output_bool("abcd" == "abdc")
|
||||||
|
output_bool(" " == " ")
|
||||||
|
|
||||||
|
|
||||||
def str_ne():
|
def str_ne():
|
||||||
@ -21,10 +24,13 @@ def str_ne():
|
|||||||
output_bool("a" != "a")
|
output_bool("a" != "a")
|
||||||
output_bool("test string" != "test string")
|
output_bool("test string" != "test string")
|
||||||
output_bool("test string1" != "test string2")
|
output_bool("test string1" != "test string2")
|
||||||
|
output_bool("test" != "testing")
|
||||||
|
output_bool("abcd" != "abdc")
|
||||||
|
output_bool(" " != " ")
|
||||||
|
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
str_eq()
|
str_eq()
|
||||||
str_ne()
|
str_ne()
|
||||||
|
|
||||||
return 0
|
return 0
|
Loading…
Reference in New Issue
Block a user