diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index f8462ab2..cd0544b6 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1897,17 +1897,39 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ) -> Result>, String> { debug_assert_eq!(comparators.len(), ops.len()); + // Handle NDArray comparisons first if comparators.len() == 1 { - let left_ty = ctx.unifier.get_representative(left.0.unwrap()); - let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap()); + let left_ty = match left.0 { + Some(ref ty) => ctx.unifier.get_representative(*ty), + None => return Err("Left type is None".to_string()), + }; + let right_ty = match comparators[0].0 { + Some(ref ty) => ctx.unifier.get_representative(*ty), + None => return Err("Right type is None".to_string()), + }; if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); - let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) }; - let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) }; + // Safely unwrap the left and right operands + let (left_ty_opt, lhs) = left; + let left_ty = match left_ty_opt { + Some(ty) => ctx.unifier.get_representative(ty), + None => return Err("Left type is None".to_string()), + }; + + let (right_ty_opt, rhs) = match comparators.first().copied() { + Some((Some(ty), val)) => (Some(ty), val), + Some((None, _)) | None => { + return Err("Comparator type is None".to_string()); + } + }; + let right_ty = match right_ty_opt { + Some(ty) => ctx.unifier.get_representative(ty), + None => return Err("Right type is None".to_string()), + }; let op = ops[0]; let is_ndarray1 = @@ -1992,6 +2014,106 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } } + // Safely unwrap the left and first comparator operands + let (left_ty_opt, lhs_val) = left; + let left_ty = match left_ty_opt { + Some(ty) => ctx.unifier.get_representative(ty), + None => return Err("Left type is None".to_string()), + }; + + let (right_ty_opt, rhs_val) = match comparators.first().copied() { + Some((Some(ty), val)) => (Some(ty), val), + Some((None, _)) | None => { + return Err("Comparator type is None".to_string()); + } + }; + let right_ty = match right_ty_opt { + Some(ty) => ctx.unifier.get_representative(ty), + None => return Err("Right type is None".to_string()), + }; + + // Handle string comparisons + if ctx.unifier.unioned(left_ty, ctx.primitives.str) + && ctx.unifier.unioned(right_ty, ctx.primitives.str) + { + // Only handle == and != for strings here + if ops.len() == 1 && (ops[0] == ast::Cmpop::Eq || ops[0] == ast::Cmpop::NotEq) { + // Extract string data + let lhs_struct = lhs_val.into_struct_value(); + let lhs_ptr = match ctx.builder.build_extract_value(lhs_struct, 0, "lhs_ptr") { + Ok(val) => val.into_pointer_value(), + Err(e) => return Err(format!("Failed to extract lhs_ptr: {e:?}")), + }; + let lhs_len = match ctx.builder.build_extract_value(lhs_struct, 1, "lhs_len") { + Ok(val) => val.into_int_value(), + Err(e) => return Err(format!("Failed to extract lhs_len: {e:?}")), + }; + + let rhs_struct = rhs_val.into_struct_value(); + let rhs_ptr = match ctx.builder.build_extract_value(rhs_struct, 0, "rhs_ptr") { + Ok(val) => val.into_pointer_value(), + Err(e) => return Err(format!("Failed to extract rhs_ptr: {e:?}")), + }; + let rhs_len = match ctx.builder.build_extract_value(rhs_struct, 1, "rhs_len") { + Ok(val) => val.into_int_value(), + Err(e) => return Err(format!("Failed to extract rhs_len: {e:?}")), + }; + + // Get or declare nac3_str_eq function + let str_eq_fn = if let Some(fun) = ctx.module.get_function("nac3_str_eq") { + fun + } else { + let bool_type = ctx.ctx.bool_type(); + let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let usize_type = generator.get_size_type(ctx.ctx); + let fn_type = bool_type.fn_type( + &[i8_ptr_type.into(), usize_type.into(), i8_ptr_type.into(), usize_type.into()], + false, + ); + ctx.module.add_function("nac3_str_eq", fn_type, None) + }; + + // Call nac3_str_eq(lhs_ptr, lhs_len, rhs_ptr, rhs_len) + let call_site = ctx + .builder + .build_call( + str_eq_fn, + &[lhs_ptr.into(), lhs_len.into(), rhs_ptr.into(), rhs_len.into()], + "str_eq_call", + ) + .expect("Failed to build call to nac3_str_eq"); + + // The function returns a bool (i1 in LLVM) + let eq_result = match call_site.try_as_basic_value() { + Either::Left(inkwell::values::BasicValueEnum::IntValue(val)) => val, + Either::Left(_) | Either::Right(_) => { + return Err("nac3_str_eq did not return an i1".to_string()) + } + }; + + // Convert i1 to i8 if NAC3 bool is i8 + let eq_i8 = match ctx.builder.build_int_z_extend(eq_result, ctx.ctx.i8_type(), "eq_i8") + { + Ok(val) => val, + Err(e) => return Err(format!("Failed to extend i1 to i8: {e:?}")), + }; + + // If the operation is NotEq, invert the result + let final_result = if ops[0] == ast::Cmpop::NotEq { + match ctx.builder.build_not(eq_i8, "neq") { + Ok(val) => val, + Err(e) => return Err(format!("Failed to invert eq_i8 for NotEq: {e:?}")), + } + } else { + eq_i8 + }; + + // Return as ValueEnum::Dynamic + return Ok(Some(ValueEnum::Dynamic(final_result.into()))); + } + return Err(format!("Operator '{:?}' not supported for strings", ops[0])); + } + let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),) .fold(Ok(None), |prev: Result, String>, (lhs, rhs, op)| { let (left_ty, lhs) = lhs; diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 73349340..a1c391a7 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -10,7 +10,7 @@ use inkwell::{ OptimizationLevel, }; use nac3parser::{ - ast::{fold::Fold, FileName, StrRef, Cmpop}, + ast::{fold::Fold, FileName, StrRef}, parser::parse_program, }; use parking_lot::RwLock; @@ -474,24 +474,3 @@ fn test_classes_ndarray_type_new() { let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into()); assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); } - -#[test] -fn test_string_equality() { - use crate::symbol_resolver::SymbolValue; - - let str1 = SymbolValue::Str("hello".to_string()); - let str2 = SymbolValue::Str("hello".to_string()); - let str3 = SymbolValue::Str("world".to_string()); - - // Test equality (==) - let result_eq = str1.evaluate_cmp_op(&Cmpop::Eq, &str2).unwrap(); - assert_eq!(result_eq, SymbolValue::Bool(true)); - - // Test inequality (!=) with different strings - let result_neq_true = str1.evaluate_cmp_op(&Cmpop::NotEq, &str3).unwrap(); - assert_eq!(result_neq_true, SymbolValue::Bool(true)); - - // Test inequality (!=) with identical strings - let result_neq_false = str1.evaluate_cmp_op(&Cmpop::NotEq, &str2).unwrap(); - assert_eq!(result_neq_false, SymbolValue::Bool(false)); -} \ No newline at end of file diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index f075ec48..541b9f7c 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -9,7 +9,7 @@ use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, Struct use itertools::{chain, izip, Itertools}; use parking_lot::RwLock; -use nac3parser::ast::{Constant, Expr, Location, StrRef, Cmpop}; +use nac3parser::ast::{Constant, Expr, Location, StrRef}; use crate::{ codegen::{CodeGenContext, CodeGenerator}, @@ -147,26 +147,26 @@ impl SymbolValue { } } - pub fn evaluate_cmp_op( - &self, - op: &Cmpop, - other: &SymbolValue, - ) -> Result { - match (self, other, op) { - // Integer comparisons - (SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)), - (SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)), - // String comparisons - (SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)), - (SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)), - // Add other types and comparison operators as needed - _ => Err(format!( - "Unsupported comparison operation for {:?} and {:?} with operator {:?}", - self, other, op - )), - } - } - + // pub fn evaluate_cmp_op( + // &self, + // op: &Cmpop, + // other: &SymbolValue, + // ) -> Result { + // match (self, other, op) { + // // // Integer comparisons + // // (SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)), + // // (SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)), + // // String comparisons + // (SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)), + // (SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)), + // // Add other types and comparison operators as needed + // _ => Err(format!( + // "Unsupported comparison operation for {:?} and {:?} with operator {:?}", + // self, other, op + // )), + // } + // } + /// Returns the [`Type`] representing the data type of this value. pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type { match self { diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 60972f03..7287c742 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -748,7 +748,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); /* str ========= */ - impl_cmpop(unifier, store, str_t, &[str_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t)); + impl_eq(unifier, store, str_t, &[str_t], Some(bool_t)); /* list ======== */ impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]); diff --git a/nac3standalone/demo/demo.c b/nac3standalone/demo/demo.c index 202a8bf6..c3ba5545 100644 --- a/nac3standalone/demo/demo.c +++ b/nac3standalone/demo/demo.c @@ -105,6 +105,14 @@ uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t __builtin_unreachable(); } +// Compare two strings by content and length. +bool nac3_str_eq(const char* lhs, size_t lhs_len, const char* rhs, size_t rhs_len) { + if (lhs_len != rhs_len) { + return false; + } + return memcmp(lhs, rhs, lhs_len) == 0; +} + // See `struct Exception<'a>` in // https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs struct Exception { diff --git a/nac3standalone/demo/src/str.py b/nac3standalone/demo/src/str.py index d28f05e2..756f1d88 100644 --- a/nac3standalone/demo/src/str.py +++ b/nac3standalone/demo/src/str.py @@ -4,24 +4,90 @@ def output_bool(x: bool): def str_eq(): + # Basic cases output_bool("" == "") output_bool("a" == "") output_bool("a" == "b") output_bool("b" == "a") output_bool("a" == "a") + + # Longer identical strings output_bool("test string" == "test string") + output_bool("Lorem ipsum dolor sit amet" == "Lorem ipsum dolor sit amet") + + # Different by one character output_bool("test string1" == "test string2") + # Numeric strings + output_bool("123" == "123") + output_bool("123" == "321") + + # Different lengths + output_bool("abc" == "abcde") + + # Case sensitivity + output_bool("Hello, World!" == "Hello, World!") + output_bool("CaseSensitive" == "casesensitive") + + # Leading and trailing spaces + output_bool(" leading space" == "leading space") + output_bool("trailing space " == "trailing space") + output_bool(" " == " ") + + # Special characters and punctuation + output_bool("special@#%$^&*()_+{}|:<>?`~chars" == "special@#%$^&*()_+{}|:<>?`~chars") + + # Unicode strings + output_bool("café" == "café") # Same accented character + output_bool("café" == "cafe") # Accented vs unaccented + + # Strings with newline and tab + output_bool("line1\nline2" == "line1\nline2") + output_bool("tab\tseparated" == "tab\tseparated") + output_bool("line1\nline2" == "line1 line2") + def str_ne(): + # Basic cases output_bool("" != "") output_bool("a" != "") output_bool("a" != "b") output_bool("b" != "a") output_bool("a" != "a") + + # Longer identical strings output_bool("test string" != "test string") + + # Different by one character output_bool("test string1" != "test string2") + # Numeric strings + output_bool("123" != "123") + output_bool("123" != "321") + + # Different lengths + output_bool("abc" != "abcde") + + # Case sensitivity + output_bool("Hello, World!" != "Hello, World!") + output_bool("CaseSensitive" != "casesensitive") + + # Leading and trailing spaces + output_bool(" leading space" != "leading space") + output_bool("trailing space " != "trailing space") + output_bool(" " != " ") + + # Special characters and punctuation + output_bool("special@#%$^&*()_+{}|:<>?`~chars" != "special@#%$^&*()_+{}|:<>?`~chars") + + # Unicode strings + output_bool("café" != "café") + output_bool("café" != "cafe") + + # Strings with newline and tab + output_bool("line1\nline2" != "line1\nline2") + output_bool("tab\tseparated" != "tab\tseparated") + output_bool("line1\nline2" != "line1 line2") def run() -> int32: str_eq()