diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 574eeef1..73349340 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -10,7 +10,7 @@ use inkwell::{ OptimizationLevel, }; use nac3parser::{ - ast::{fold::Fold, FileName, StrRef}, + ast::{fold::Fold, FileName, StrRef, Cmpop}, parser::parse_program, }; use parking_lot::RwLock; @@ -476,32 +476,22 @@ fn test_classes_ndarray_type_new() { } #[test] -fn test_string_equality(){ +fn test_string_equality() { use crate::symbol_resolver::SymbolValue; - use crate::typedef::{PrimitiveStore, Unifier}; - use crate::magic_methods::{Binop, Operator}; - use nac3parser::ast::Cmpop; - - let primitives = PrimitiveStore::default(); - let mut unifier = Unifier::default(); let str1 = SymbolValue::Str("hello".to_string()); let str2 = SymbolValue::Str("hello".to_string()); let str3 = SymbolValue::Str("world".to_string()); - // Create binary operators for equality and inequality - let eq_op = Binop::normal(Operator::Eq); - let neq_op = Binop::normal(Operator::NotEq); - // Test equality (==) - let result_eq = str1.evaluate_binary_op(&eq_op, &str2, &primitives, &mut unifier).unwrap(); + let result_eq = str1.evaluate_cmp_op(&Cmpop::Eq, &str2).unwrap(); assert_eq!(result_eq, SymbolValue::Bool(true)); // Test inequality (!=) with different strings - let result_neq_true = str1.evaluate_binary_op(&neq_op, &str3, &primitives, &mut unifier).unwrap(); + let result_neq_true = str1.evaluate_cmp_op(&Cmpop::NotEq, &str3).unwrap(); assert_eq!(result_neq_true, SymbolValue::Bool(true)); // Test inequality (!=) with identical strings - let result_neq_false = str1.evaluate_binary_op(&neq_op, &str2, &primitives, &mut unifier).unwrap(); + let result_neq_false = str1.evaluate_cmp_op(&Cmpop::NotEq, &str2).unwrap(); assert_eq!(result_neq_false, SymbolValue::Bool(false)); } \ No newline at end of file diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 82043744..f075ec48 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -9,7 +9,7 @@ use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, Struct use itertools::{chain, izip, Itertools}; use parking_lot::RwLock; -use nac3parser::ast::{Constant, Expr, Location, StrRef}; +use nac3parser::ast::{Constant, Expr, Location, StrRef, Cmpop}; use crate::{ codegen::{CodeGenContext, CodeGenerator}, @@ -147,21 +147,26 @@ impl SymbolValue { } } - /// Evaluate binary operations - pub fn evaluate_binary_op( + pub fn evaluate_cmp_op( &self, + op: &Cmpop, other: &SymbolValue, - op: fn(i64, i64) -> i64, ) -> Result { - match (self, other) { - (SymbolValue::I32(a), SymbolValue::I32(b)) => Ok(SymbolValue::I32(op(*a as i64, *b as i64) as i32)), - (SymbolValue::I64(a), SymbolValue::I64(b)) => Ok(SymbolValue::I64(op(*a, *b))), - (SymbolValue::U32(a), SymbolValue::U32(b)) => Ok(SymbolValue::U32(op(*a as i64, *b as i64) as u32)), - (SymbolValue::U64(a), SymbolValue::U64(b)) => Ok(SymbolValue::U64(op(*a as i64, *b as i64) as u64)), - _ => Err(format!("Unsupported binary operation for {self} and {other}")), + match (self, other, op) { + // Integer comparisons + (SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)), + (SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)), + // String comparisons + (SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)), + (SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)), + // Add other types and comparison operators as needed + _ => Err(format!( + "Unsupported comparison operation for {:?} and {:?} with operator {:?}", + self, other, op + )), } - } - + } + /// Returns the [`Type`] representing the data type of this value. pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type { match self {