1
0
forked from M-Labs/nac3

Implement string equality operator in symbol_resolver and update tests

This commit is contained in:
ram 2024-12-03 18:17:58 +00:00
parent 399af54043
commit 4211858273
2 changed files with 22 additions and 27 deletions

View File

@ -10,7 +10,7 @@ use inkwell::{
OptimizationLevel, OptimizationLevel,
}; };
use nac3parser::{ use nac3parser::{
ast::{fold::Fold, FileName, StrRef}, ast::{fold::Fold, FileName, StrRef, Cmpop},
parser::parse_program, parser::parse_program,
}; };
use parking_lot::RwLock; use parking_lot::RwLock;
@ -476,32 +476,22 @@ fn test_classes_ndarray_type_new() {
} }
#[test] #[test]
fn test_string_equality(){ fn test_string_equality() {
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::typedef::{PrimitiveStore, Unifier};
use crate::magic_methods::{Binop, Operator};
use nac3parser::ast::Cmpop;
let primitives = PrimitiveStore::default();
let mut unifier = Unifier::default();
let str1 = SymbolValue::Str("hello".to_string()); let str1 = SymbolValue::Str("hello".to_string());
let str2 = SymbolValue::Str("hello".to_string()); let str2 = SymbolValue::Str("hello".to_string());
let str3 = SymbolValue::Str("world".to_string()); let str3 = SymbolValue::Str("world".to_string());
// Create binary operators for equality and inequality
let eq_op = Binop::normal(Operator::Eq);
let neq_op = Binop::normal(Operator::NotEq);
// Test equality (==) // Test equality (==)
let result_eq = str1.evaluate_binary_op(&eq_op, &str2, &primitives, &mut unifier).unwrap(); let result_eq = str1.evaluate_cmp_op(&Cmpop::Eq, &str2).unwrap();
assert_eq!(result_eq, SymbolValue::Bool(true)); assert_eq!(result_eq, SymbolValue::Bool(true));
// Test inequality (!=) with different strings // Test inequality (!=) with different strings
let result_neq_true = str1.evaluate_binary_op(&neq_op, &str3, &primitives, &mut unifier).unwrap(); let result_neq_true = str1.evaluate_cmp_op(&Cmpop::NotEq, &str3).unwrap();
assert_eq!(result_neq_true, SymbolValue::Bool(true)); assert_eq!(result_neq_true, SymbolValue::Bool(true));
// Test inequality (!=) with identical strings // Test inequality (!=) with identical strings
let result_neq_false = str1.evaluate_binary_op(&neq_op, &str2, &primitives, &mut unifier).unwrap(); let result_neq_false = str1.evaluate_cmp_op(&Cmpop::NotEq, &str2).unwrap();
assert_eq!(result_neq_false, SymbolValue::Bool(false)); assert_eq!(result_neq_false, SymbolValue::Bool(false));
} }

View File

@ -9,7 +9,7 @@ use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, Struct
use itertools::{chain, izip, Itertools}; use itertools::{chain, izip, Itertools};
use parking_lot::RwLock; use parking_lot::RwLock;
use nac3parser::ast::{Constant, Expr, Location, StrRef}; use nac3parser::ast::{Constant, Expr, Location, StrRef, Cmpop};
use crate::{ use crate::{
codegen::{CodeGenContext, CodeGenerator}, codegen::{CodeGenContext, CodeGenerator},
@ -147,21 +147,26 @@ impl SymbolValue {
} }
} }
/// Evaluate binary operations pub fn evaluate_cmp_op(
pub fn evaluate_binary_op(
&self, &self,
op: &Cmpop,
other: &SymbolValue, other: &SymbolValue,
op: fn(i64, i64) -> i64,
) -> Result<SymbolValue, String> { ) -> Result<SymbolValue, String> {
match (self, other) { match (self, other, op) {
(SymbolValue::I32(a), SymbolValue::I32(b)) => Ok(SymbolValue::I32(op(*a as i64, *b as i64) as i32)), // Integer comparisons
(SymbolValue::I64(a), SymbolValue::I64(b)) => Ok(SymbolValue::I64(op(*a, *b))), (SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)),
(SymbolValue::U32(a), SymbolValue::U32(b)) => Ok(SymbolValue::U32(op(*a as i64, *b as i64) as u32)), (SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)),
(SymbolValue::U64(a), SymbolValue::U64(b)) => Ok(SymbolValue::U64(op(*a as i64, *b as i64) as u64)), // String comparisons
_ => Err(format!("Unsupported binary operation for {self} and {other}")), (SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)),
(SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)),
// Add other types and comparison operators as needed
_ => Err(format!(
"Unsupported comparison operation for {:?} and {:?} with operator {:?}",
self, other, op
)),
} }
} }
/// Returns the [`Type`] representing the data type of this value. /// Returns the [`Type`] representing the data type of this value.
pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type { pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type {
match self { match self {