From f2880dce03c036971737afeff3bb65aaa0157a3e Mon Sep 17 00:00:00 2001 From: ram Date: Tue, 10 Dec 2024 14:43:16 +0000 Subject: [PATCH] Implement string equality comparison in irrt with relevant test cases --- nac3core/irrt/irrt.cpp | 1 + nac3core/irrt/irrt/string.hpp | 22 +++++ nac3core/src/codegen/expr.rs | 107 +++--------------------- nac3core/src/codegen/irrt/mod.rs | 1 + nac3core/src/codegen/irrt/string.rs | 45 ++++++++++ nac3core/src/typecheck/magic_methods.rs | 2 +- nac3standalone/demo/demo.c | 10 +-- nac3standalone/demo/src/str.py | 59 ++++--------- 8 files changed, 99 insertions(+), 148 deletions(-) create mode 100644 nac3core/irrt/irrt/string.hpp create mode 100644 nac3core/src/codegen/irrt/string.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 7966322..05d07ad 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -3,3 +3,4 @@ #include "irrt/math.hpp" #include "irrt/ndarray.hpp" #include "irrt/slice.hpp" +#include "irrt/string.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/string.hpp b/nac3core/irrt/irrt/string.hpp new file mode 100644 index 0000000..5fc3571 --- /dev/null +++ b/nac3core/irrt/irrt/string.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "irrt/int_types.hpp" +namespace { +// Returns 1 if both byte ranges have the same length and the same contents, 0 otherwise. +template<typename SizeT> +int32_t __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) { + // Lengths are checked before the pointer shortcut: one pointer with two different lengths is not equal. + if (len1 != len2) return 0; + if (str1 == str2) return 1; + for (SizeT i = 0; i < len1; ++i) { + if (str1[i] != str2[i]) return 0; + } + return 1; +} +} // namespace + +extern "C" { +int32_t nac3_str_eq(const char* str1, uint64_t len1, const char* str2, uint64_t len2) { + return __nac3_str_eq_impl(str1, len1, str2, len2); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 632619e..f8462ab 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1898,33 +1898,16 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: 
CodeGenerator>( debug_assert_eq!(comparators.len(), ops.len()); if comparators.len() == 1 { - let (Some(left_ty), _) = left else { codegen_unreachable!(ctx) }; - let left_ty = ctx.unifier.get_representative(left_ty); - - let (Some(right_ty), _) = comparators[0] else { codegen_unreachable!(ctx) }; - let right_ty = ctx.unifier.get_representative(right_ty); + let left_ty = ctx.unifier.get_representative(left.0.unwrap()); + let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap()); if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); - let (left_ty_opt, lhs) = left; - let left_ty = match left_ty_opt { - Some(ty) => ctx.unifier.get_representative(ty), - None => codegen_unreachable!(ctx), - }; - - let (right_ty_opt, rhs) = match comparators.first().copied() { - Some((Some(ty), val)) => (Some(ty), val), - Some((None, _)) | None => { - codegen_unreachable!(ctx); - } - }; - let right_ty = match right_ty_opt { - Some(ty) => ctx.unifier.get_representative(ty), - None => codegen_unreachable!(ctx), - }; + let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) }; + let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) }; let op = ops[0]; let is_ndarray1 = @@ -2009,77 +1992,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } } - let (Some(left_ty), lhs_val) = left else { codegen_unreachable!(ctx) }; - let left_ty = ctx.unifier.get_representative(left_ty); - - let (Some(right_ty), rhs_val) = comparators.first().copied().unwrap() else { - codegen_unreachable!(ctx) - }; - let right_ty = ctx.unifier.get_representative(right_ty); - - if ctx.unifier.unioned(left_ty, ctx.primitives.str) - && ctx.unifier.unioned(right_ty, ctx.primitives.str) - { - if ops.len() == 1 && (ops[0] == ast::Cmpop::Eq || ops[0] == ast::Cmpop::NotEq) { - let lhs_struct = lhs_val.into_struct_value(); 
- let lhs_ptr = ctx - .builder - .build_extract_value(lhs_struct, 0, "lhs_ptr") - .unwrap() - .into_pointer_value(); - let lhs_len = - ctx.builder.build_extract_value(lhs_struct, 1, "lhs_len").unwrap().into_int_value(); - - let rhs_struct = rhs_val.into_struct_value(); - let rhs_ptr = ctx - .builder - .build_extract_value(rhs_struct, 0, "rhs_ptr") - .unwrap() - .into_pointer_value(); - let rhs_len = - ctx.builder.build_extract_value(rhs_struct, 1, "rhs_len").unwrap().into_int_value(); - - let str_eq_fn = if let Some(fun) = ctx.module.get_function("nac3_str_eq") { - fun - } else { - let bool_type = ctx.ctx.bool_type(); - let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let usize_type = generator.get_size_type(ctx.ctx); - let fn_type = bool_type.fn_type( - &[i8_ptr_type.into(), usize_type.into(), i8_ptr_type.into(), usize_type.into()], - false, - ); - ctx.module.add_function("nac3_str_eq", fn_type, None) - }; - - let call_site = ctx - .builder - .build_call( - str_eq_fn, - &[lhs_ptr.into(), lhs_len.into(), rhs_ptr.into(), rhs_len.into()], - "str_eq_call", - ) - .expect("Failed to build call to nac3_str_eq"); - - let eq_result = match call_site.try_as_basic_value() { - Either::Left(inkwell::values::BasicValueEnum::IntValue(val)) => val, - Either::Left(_) | Either::Right(_) => codegen_unreachable!(ctx), - }; - - let eq_i8 = - ctx.builder.build_int_z_extend(eq_result, ctx.ctx.i8_type(), "eq_i8").unwrap(); - - let final_result = if ops[0] == ast::Cmpop::NotEq { - ctx.builder.build_not(eq_i8, "neq").unwrap() - } else { - eq_i8 - }; - - return Ok(Some(ValueEnum::Dynamic(final_result.into()))); - } - codegen_unreachable!(ctx); - } - let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),) .fold(Ok(None), |prev: Result, String>, (lhs, rhs, op)| { let (left_ty, lhs) = lhs; @@ -2314,7 +2226,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } if ![Cmpop::Eq, Cmpop::NotEq].contains(op) { - 
codegen_unreachable!(ctx, "Only __eq__ and __ne__ supported for this type") + todo!("Only __eq__ and __ne__ is implemented for lists") } let left_val = @@ -2438,10 +2350,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( gen_list_cmpop(generator, ctx)? } else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) { let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else { - codegen_unreachable!(ctx) + return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) }; let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else { - codegen_unreachable!(ctx) + return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) }; if ![Cmpop::Eq, Cmpop::NotEq].contains(op) { @@ -2566,7 +2478,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().get_poison() } else { - codegen_unreachable!(ctx) + return Err(format!("'{}' not supported between instances of '{}' and '{}'", + op.op_info().symbol, + ctx.unifier.stringify(left_ty), + ctx.unifier.stringify(right_ty))) }; Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index f6c4a1e..5d45d69 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -20,6 +20,7 @@ mod list; mod math; mod ndarray; mod slice; +mod string; #[must_use] pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> { diff --git a/nac3core/src/codegen/irrt/string.rs b/nac3core/src/codegen/irrt/string.rs new file mode 100644 index 0000000..0a9d3a1 --- /dev/null +++ b/nac3core/src/codegen/irrt/string.rs @@ -0,0 +1,45 @@ 
+use inkwell::{ + values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue}, + AddressSpace, +}; +use itertools::Either; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +/// Emits a call to `nac3_str_eq`, declaring it in the module as `i32(i8*, i64, i8*, i64)` when it is not already present; both lengths are zero-extended to `i64` and the `i32` result is converted with `bool_to_i1` before being returned. Panics if the call cannot be built or does not yield a basic value. +pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + str1_ptr: PointerValue<'ctx>, + str1_len: IntValue<'ctx>, + str2_ptr: PointerValue<'ctx>, + str2_len: IntValue<'ctx>, +) -> IntValue<'ctx> { + let string_eq_fn = ctx.module.get_function("nac3_str_eq").unwrap_or_else(|| { + let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let i64_type = ctx.ctx.i64_type(); + let i32_type = ctx.ctx.i32_type(); + let fn_type = i32_type.fn_type( + &[i8_ptr_type.into(), i64_type.into(), i8_ptr_type.into(), i64_type.into()], + false, + ); + ctx.module.add_function("nac3_str_eq", fn_type, None) + }); + let result = ctx + .builder + .build_call( + string_eq_fn, + &[ + str1_ptr.into(), + ctx.builder.build_int_z_extend(str1_len, ctx.ctx.i64_type(), "").unwrap().into(), + str2_ptr.into(), + ctx.builder.build_int_z_extend(str2_len, ctx.ctx.i64_type(), "").unwrap().into(), + ], + "string_eq", + ) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + generator.bool_to_i1(ctx, result) +} diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 7287c74..60972f0 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -748,7 +748,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); /* str ========= */ - impl_eq(unifier, store, str_t, &[str_t], Some(bool_t)); + impl_cmpop(unifier, store, str_t, &[str_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t)); /* list 
======== */ impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]); diff --git a/nac3standalone/demo/demo.c b/nac3standalone/demo/demo.c index c3ba554..bcd8f28 100644 --- a/nac3standalone/demo/demo.c +++ b/nac3standalone/demo/demo.c @@ -105,14 +105,6 @@ uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t __builtin_unreachable(); } -// Compare two strings by content and length. -bool nac3_str_eq(const char* lhs, size_t lhs_len, const char* rhs, size_t rhs_len) { - if (lhs_len != rhs_len) { - return false; - } - return memcmp(lhs, rhs, lhs_len) == 0; -} - // See `struct Exception<'a>` in // https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs struct Exception { @@ -143,4 +135,4 @@ extern int32_t run(void); int main(void) { run(); -} +} \ No newline at end of file diff --git a/nac3standalone/demo/src/str.py b/nac3standalone/demo/src/str.py index a9d1dcb..5c834d9 100644 --- a/nac3standalone/demo/src/str.py +++ b/nac3standalone/demo/src/str.py @@ -2,62 +2,37 @@ def output_bool(x: bool): ... 
- -def str_eq(): - # Basic cases +def test_str_eq(): output_bool("" == "") output_bool("a" == "") - output_bool("a" == "b") - output_bool("b" == "a") output_bool("a" == "a") - - # Longer identical strings + output_bool("a" == "b") output_bool("test string" == "test string") output_bool("Lorem ipsum dolor sit amet" == "Lorem ipsum dolor sit amet") - - # Different by one character - output_bool("test string1" == "test string2") - - # Numeric strings + output_bool("test1" == "test2") output_bool("123" == "123") output_bool("123" == "321") - - # Different lengths output_bool("abc" == "abcde") + output_bool("a" == "aa") + output_bool(" " == " ") + output_bool(" a " == " a ") - # Leading and trailing spaces - output_bool(" leading space" == "leading space") - output_bool("trailing space " == "trailing space") - output_bool(" " == " ") - -def str_ne(): - # Basic cases +def test_str_ne(): output_bool("" != "") output_bool("a" != "") - output_bool("a" != "b") - output_bool("b" != "a") output_bool("a" != "a") - - # Longer identical strings + output_bool("a" != "b") output_bool("test string" != "test string") - - # Different by one character - output_bool("test string1" != "test string2") - - # Numeric strings + output_bool("Lorem ipsum dolor sit amet" != "Lorem ipsum dolor sit amet") + output_bool("test1" != "test2") output_bool("123" != "123") output_bool("123" != "321") - - # Different lengths output_bool("abc" != "abcde") - - # Leading and trailing spaces - output_bool(" leading space" != "leading space") - output_bool("trailing space " != "trailing space") - output_bool(" " != " ") - + output_bool("a" != "aa") + output_bool(" " != " ") + output_bool(" a " != " a ") + def run() -> int32: - str_eq() - str_ne() - - return 0 + test_str_eq() + test_str_ne() + return 0 \ No newline at end of file