forked from M-Labs/nac3
Implement string equality comparison in irrt with relevant test cases
This commit is contained in:
parent
c6e6b7bc95
commit
f2880dce03
@ -3,3 +3,4 @@
|
|||||||
#include "irrt/math.hpp"
|
#include "irrt/math.hpp"
|
||||||
#include "irrt/ndarray.hpp"
|
#include "irrt/ndarray.hpp"
|
||||||
#include "irrt/slice.hpp"
|
#include "irrt/slice.hpp"
|
||||||
|
#include "irrt/string.hpp"
|
22
nac3core/irrt/irrt/string.hpp
Normal file
22
nac3core/irrt/irrt/string.hpp
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "irrt/int_types.hpp"
|
||||||
|
namespace {
|
||||||
|
template<typename SizeT>
|
||||||
|
int32_t __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) {
|
||||||
|
if (str1 == str2) return 1;
|
||||||
|
if (len1 != len2) return 0;
|
||||||
|
for (SizeT i = 0; i < len1; ++i) {
|
||||||
|
if (static_cast<unsigned char>(str1[i]) != static_cast<unsigned char>(str2[i])) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
int32_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) {
|
||||||
|
return __nac3_str_eq_impl<uint64_t>(str1, len1, str2, len2);
|
||||||
|
}
|
||||||
|
}
|
@ -1898,33 +1898,16 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
debug_assert_eq!(comparators.len(), ops.len());
|
debug_assert_eq!(comparators.len(), ops.len());
|
||||||
|
|
||||||
if comparators.len() == 1 {
|
if comparators.len() == 1 {
|
||||||
let (Some(left_ty), _) = left else { codegen_unreachable!(ctx) };
|
let left_ty = ctx.unifier.get_representative(left.0.unwrap());
|
||||||
let left_ty = ctx.unifier.get_representative(left_ty);
|
let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap());
|
||||||
|
|
||||||
let (Some(right_ty), _) = comparators[0] else { codegen_unreachable!(ctx) };
|
|
||||||
let right_ty = ctx.unifier.get_representative(right_ty);
|
|
||||||
|
|
||||||
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let (left_ty_opt, lhs) = left;
|
let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) };
|
||||||
let left_ty = match left_ty_opt {
|
let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) };
|
||||||
Some(ty) => ctx.unifier.get_representative(ty),
|
|
||||||
None => codegen_unreachable!(ctx),
|
|
||||||
};
|
|
||||||
|
|
||||||
let (right_ty_opt, rhs) = match comparators.first().copied() {
|
|
||||||
Some((Some(ty), val)) => (Some(ty), val),
|
|
||||||
Some((None, _)) | None => {
|
|
||||||
codegen_unreachable!(ctx);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let right_ty = match right_ty_opt {
|
|
||||||
Some(ty) => ctx.unifier.get_representative(ty),
|
|
||||||
None => codegen_unreachable!(ctx),
|
|
||||||
};
|
|
||||||
let op = ops[0];
|
let op = ops[0];
|
||||||
|
|
||||||
let is_ndarray1 =
|
let is_ndarray1 =
|
||||||
@ -2009,77 +1992,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (Some(left_ty), lhs_val) = left else { codegen_unreachable!(ctx) };
|
|
||||||
let left_ty = ctx.unifier.get_representative(left_ty);
|
|
||||||
|
|
||||||
let (Some(right_ty), rhs_val) = comparators.first().copied().unwrap() else {
|
|
||||||
codegen_unreachable!(ctx)
|
|
||||||
};
|
|
||||||
let right_ty = ctx.unifier.get_representative(right_ty);
|
|
||||||
|
|
||||||
if ctx.unifier.unioned(left_ty, ctx.primitives.str)
|
|
||||||
&& ctx.unifier.unioned(right_ty, ctx.primitives.str)
|
|
||||||
{
|
|
||||||
if ops.len() == 1 && (ops[0] == ast::Cmpop::Eq || ops[0] == ast::Cmpop::NotEq) {
|
|
||||||
let lhs_struct = lhs_val.into_struct_value();
|
|
||||||
let lhs_ptr = ctx
|
|
||||||
.builder
|
|
||||||
.build_extract_value(lhs_struct, 0, "lhs_ptr")
|
|
||||||
.unwrap()
|
|
||||||
.into_pointer_value();
|
|
||||||
let lhs_len =
|
|
||||||
ctx.builder.build_extract_value(lhs_struct, 1, "lhs_len").unwrap().into_int_value();
|
|
||||||
|
|
||||||
let rhs_struct = rhs_val.into_struct_value();
|
|
||||||
let rhs_ptr = ctx
|
|
||||||
.builder
|
|
||||||
.build_extract_value(rhs_struct, 0, "rhs_ptr")
|
|
||||||
.unwrap()
|
|
||||||
.into_pointer_value();
|
|
||||||
let rhs_len =
|
|
||||||
ctx.builder.build_extract_value(rhs_struct, 1, "rhs_len").unwrap().into_int_value();
|
|
||||||
|
|
||||||
let str_eq_fn = if let Some(fun) = ctx.module.get_function("nac3_str_eq") {
|
|
||||||
fun
|
|
||||||
} else {
|
|
||||||
let bool_type = ctx.ctx.bool_type();
|
|
||||||
let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
|
||||||
let usize_type = generator.get_size_type(ctx.ctx);
|
|
||||||
let fn_type = bool_type.fn_type(
|
|
||||||
&[i8_ptr_type.into(), usize_type.into(), i8_ptr_type.into(), usize_type.into()],
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
ctx.module.add_function("nac3_str_eq", fn_type, None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let call_site = ctx
|
|
||||||
.builder
|
|
||||||
.build_call(
|
|
||||||
str_eq_fn,
|
|
||||||
&[lhs_ptr.into(), lhs_len.into(), rhs_ptr.into(), rhs_len.into()],
|
|
||||||
"str_eq_call",
|
|
||||||
)
|
|
||||||
.expect("Failed to build call to nac3_str_eq");
|
|
||||||
|
|
||||||
let eq_result = match call_site.try_as_basic_value() {
|
|
||||||
Either::Left(inkwell::values::BasicValueEnum::IntValue(val)) => val,
|
|
||||||
Either::Left(_) | Either::Right(_) => codegen_unreachable!(ctx),
|
|
||||||
};
|
|
||||||
|
|
||||||
let eq_i8 =
|
|
||||||
ctx.builder.build_int_z_extend(eq_result, ctx.ctx.i8_type(), "eq_i8").unwrap();
|
|
||||||
|
|
||||||
let final_result = if ops[0] == ast::Cmpop::NotEq {
|
|
||||||
ctx.builder.build_not(eq_i8, "neq").unwrap()
|
|
||||||
} else {
|
|
||||||
eq_i8
|
|
||||||
};
|
|
||||||
|
|
||||||
return Ok(Some(ValueEnum::Dynamic(final_result.into())));
|
|
||||||
}
|
|
||||||
codegen_unreachable!(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
|
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
|
||||||
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
|
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
|
||||||
let (left_ty, lhs) = lhs;
|
let (left_ty, lhs) = lhs;
|
||||||
@ -2314,7 +2226,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ![Cmpop::Eq, Cmpop::NotEq].contains(op) {
|
if ![Cmpop::Eq, Cmpop::NotEq].contains(op) {
|
||||||
codegen_unreachable!(ctx, "Only __eq__ and __ne__ supported for this type")
|
todo!("Only __eq__ and __ne__ is implemented for lists")
|
||||||
}
|
}
|
||||||
|
|
||||||
let left_val =
|
let left_val =
|
||||||
@ -2438,10 +2350,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
gen_list_cmpop(generator, ctx)?
|
gen_list_cmpop(generator, ctx)?
|
||||||
} else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) {
|
} else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) {
|
||||||
let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else {
|
let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else {
|
||||||
codegen_unreachable!(ctx)
|
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty)))
|
||||||
};
|
};
|
||||||
let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else {
|
let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else {
|
||||||
codegen_unreachable!(ctx)
|
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty)))
|
||||||
};
|
};
|
||||||
|
|
||||||
if ![Cmpop::Eq, Cmpop::NotEq].contains(op) {
|
if ![Cmpop::Eq, Cmpop::NotEq].contains(op) {
|
||||||
@ -2566,7 +2478,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
|
|
||||||
ctx.ctx.bool_type().get_poison()
|
ctx.ctx.bool_type().get_poison()
|
||||||
} else {
|
} else {
|
||||||
codegen_unreachable!(ctx)
|
return Err(format!("'{}' not supported between instances of '{}' and '{}'",
|
||||||
|
op.op_info().symbol,
|
||||||
|
ctx.unifier.stringify(left_ty),
|
||||||
|
ctx.unifier.stringify(right_ty)))
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
||||||
|
@ -20,6 +20,7 @@ mod list;
|
|||||||
mod math;
|
mod math;
|
||||||
mod ndarray;
|
mod ndarray;
|
||||||
mod slice;
|
mod slice;
|
||||||
|
mod string;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
||||||
|
45
nac3core/src/codegen/irrt/string.rs
Normal file
45
nac3core/src/codegen/irrt/string.rs
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
use inkwell::{
|
||||||
|
values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
use itertools::Either;
|
||||||
|
|
||||||
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
|
||||||
|
pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
str1_ptr: PointerValue<'ctx>,
|
||||||
|
str1_len: IntValue<'ctx>,
|
||||||
|
str2_ptr: PointerValue<'ctx>,
|
||||||
|
str2_len: IntValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let string_eq_fn = ctx.module.get_function("nac3_str_eq").unwrap_or_else(|| {
|
||||||
|
let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||||
|
let i64_type = ctx.ctx.i64_type();
|
||||||
|
let i32_type = ctx.ctx.i32_type();
|
||||||
|
let fn_type = i32_type.fn_type(
|
||||||
|
&[i8_ptr_type.into(), i64_type.into(), i8_ptr_type.into(), i64_type.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
ctx.module.add_function("nac3_str_eq", fn_type, None)
|
||||||
|
});
|
||||||
|
let result = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(
|
||||||
|
string_eq_fn,
|
||||||
|
&[
|
||||||
|
str1_ptr.into(),
|
||||||
|
ctx.builder.build_int_z_extend(str1_len, ctx.ctx.i64_type(), "").unwrap().into(),
|
||||||
|
str2_ptr.into(),
|
||||||
|
ctx.builder.build_int_z_extend(str2_len, ctx.ctx.i64_type(), "").unwrap().into(),
|
||||||
|
],
|
||||||
|
"string_eq",
|
||||||
|
)
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap();
|
||||||
|
generator.bool_to_i1(ctx, result)
|
||||||
|
}
|
@ -748,7 +748,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||||
|
|
||||||
/* str ========= */
|
/* str ========= */
|
||||||
impl_eq(unifier, store, str_t, &[str_t], Some(bool_t));
|
impl_cmpop(unifier, store, str_t, &[str_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
|
||||||
|
|
||||||
/* list ======== */
|
/* list ======== */
|
||||||
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
|
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
|
||||||
|
@ -105,14 +105,6 @@ uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t
|
|||||||
__builtin_unreachable();
|
__builtin_unreachable();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compare two strings by content and length.
|
|
||||||
bool nac3_str_eq(const char* lhs, size_t lhs_len, const char* rhs, size_t rhs_len) {
|
|
||||||
if (lhs_len != rhs_len) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return memcmp(lhs, rhs, lhs_len) == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// See `struct Exception<'a>` in
|
// See `struct Exception<'a>` in
|
||||||
// https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs
|
// https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs
|
||||||
struct Exception {
|
struct Exception {
|
||||||
@ -143,4 +135,4 @@ extern int32_t run(void);
|
|||||||
|
|
||||||
int main(void) {
|
int main(void) {
|
||||||
run();
|
run();
|
||||||
}
|
}
|
@ -2,62 +2,37 @@
|
|||||||
def output_bool(x: bool):
|
def output_bool(x: bool):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def test_str_eq():
|
||||||
def str_eq():
|
|
||||||
# Basic cases
|
|
||||||
output_bool("" == "")
|
output_bool("" == "")
|
||||||
output_bool("a" == "")
|
output_bool("a" == "")
|
||||||
output_bool("a" == "b")
|
|
||||||
output_bool("b" == "a")
|
|
||||||
output_bool("a" == "a")
|
output_bool("a" == "a")
|
||||||
|
output_bool("a" == "b")
|
||||||
# Longer identical strings
|
|
||||||
output_bool("test string" == "test string")
|
output_bool("test string" == "test string")
|
||||||
output_bool("Lorem ipsum dolor sit amet" == "Lorem ipsum dolor sit amet")
|
output_bool("Lorem ipsum dolor sit amet" == "Lorem ipsum dolor sit amet")
|
||||||
|
output_bool("test1" == "test2")
|
||||||
# Different by one character
|
|
||||||
output_bool("test string1" == "test string2")
|
|
||||||
|
|
||||||
# Numeric strings
|
|
||||||
output_bool("123" == "123")
|
output_bool("123" == "123")
|
||||||
output_bool("123" == "321")
|
output_bool("123" == "321")
|
||||||
|
|
||||||
# Different lengths
|
|
||||||
output_bool("abc" == "abcde")
|
output_bool("abc" == "abcde")
|
||||||
|
output_bool("a" == "aa")
|
||||||
|
output_bool(" " == " ")
|
||||||
|
output_bool(" a " == " a ")
|
||||||
|
|
||||||
# Leading and trailing spaces
|
def test_str_ne():
|
||||||
output_bool(" leading space" == "leading space")
|
|
||||||
output_bool("trailing space " == "trailing space")
|
|
||||||
output_bool(" " == " ")
|
|
||||||
|
|
||||||
def str_ne():
|
|
||||||
# Basic cases
|
|
||||||
output_bool("" != "")
|
output_bool("" != "")
|
||||||
output_bool("a" != "")
|
output_bool("a" != "")
|
||||||
output_bool("a" != "b")
|
|
||||||
output_bool("b" != "a")
|
|
||||||
output_bool("a" != "a")
|
output_bool("a" != "a")
|
||||||
|
output_bool("a" != "b")
|
||||||
# Longer identical strings
|
|
||||||
output_bool("test string" != "test string")
|
output_bool("test string" != "test string")
|
||||||
|
output_bool("Lorem ipsum dolor sit amet" != "Lorem ipsum dolor sit amet")
|
||||||
# Different by one character
|
output_bool("test1" != "test2")
|
||||||
output_bool("test string1" != "test string2")
|
|
||||||
|
|
||||||
# Numeric strings
|
|
||||||
output_bool("123" != "123")
|
output_bool("123" != "123")
|
||||||
output_bool("123" != "321")
|
output_bool("123" != "321")
|
||||||
|
|
||||||
# Different lengths
|
|
||||||
output_bool("abc" != "abcde")
|
output_bool("abc" != "abcde")
|
||||||
|
output_bool("a" != "aa")
|
||||||
# Leading and trailing spaces
|
output_bool(" " != " ")
|
||||||
output_bool(" leading space" != "leading space")
|
output_bool(" a " != " a ")
|
||||||
output_bool("trailing space " != "trailing space")
|
|
||||||
output_bool(" " != " ")
|
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
str_eq()
|
test_str_eq()
|
||||||
str_ne()
|
test_str_ne()
|
||||||
|
return 0
|
||||||
return 0
|
|
Loading…
Reference in New Issue
Block a user