forked from M-Labs/nac3
Implement string equality and inequality comparisons in 'expr.rs', Implement nac3_str_eq in demo.c and Utilise nac3_str_eq for string comparisons.
This commit is contained in:
parent
4211858273
commit
887f24093f
@ -1897,17 +1897,39 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||||
debug_assert_eq!(comparators.len(), ops.len());
|
debug_assert_eq!(comparators.len(), ops.len());
|
||||||
|
|
||||||
|
// Handle NDArray comparisons first
|
||||||
if comparators.len() == 1 {
|
if comparators.len() == 1 {
|
||||||
let left_ty = ctx.unifier.get_representative(left.0.unwrap());
|
let left_ty = match left.0 {
|
||||||
let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap());
|
Some(ref ty) => ctx.unifier.get_representative(*ty),
|
||||||
|
None => return Err("Left type is None".to_string()),
|
||||||
|
};
|
||||||
|
let right_ty = match comparators[0].0 {
|
||||||
|
Some(ref ty) => ctx.unifier.get_representative(*ty),
|
||||||
|
None => return Err("Right type is None".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) };
|
// Safely unwrap the left and right operands
|
||||||
let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) };
|
let (left_ty_opt, lhs) = left;
|
||||||
|
let left_ty = match left_ty_opt {
|
||||||
|
Some(ty) => ctx.unifier.get_representative(ty),
|
||||||
|
None => return Err("Left type is None".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let (right_ty_opt, rhs) = match comparators.first().copied() {
|
||||||
|
Some((Some(ty), val)) => (Some(ty), val),
|
||||||
|
Some((None, _)) | None => {
|
||||||
|
return Err("Comparator type is None".to_string());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let right_ty = match right_ty_opt {
|
||||||
|
Some(ty) => ctx.unifier.get_representative(ty),
|
||||||
|
None => return Err("Right type is None".to_string()),
|
||||||
|
};
|
||||||
let op = ops[0];
|
let op = ops[0];
|
||||||
|
|
||||||
let is_ndarray1 =
|
let is_ndarray1 =
|
||||||
@ -1992,6 +2014,106 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Safely unwrap the left and first comparator operands
|
||||||
|
let (left_ty_opt, lhs_val) = left;
|
||||||
|
let left_ty = match left_ty_opt {
|
||||||
|
Some(ty) => ctx.unifier.get_representative(ty),
|
||||||
|
None => return Err("Left type is None".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let (right_ty_opt, rhs_val) = match comparators.first().copied() {
|
||||||
|
Some((Some(ty), val)) => (Some(ty), val),
|
||||||
|
Some((None, _)) | None => {
|
||||||
|
return Err("Comparator type is None".to_string());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let right_ty = match right_ty_opt {
|
||||||
|
Some(ty) => ctx.unifier.get_representative(ty),
|
||||||
|
None => return Err("Right type is None".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle string comparisons
|
||||||
|
if ctx.unifier.unioned(left_ty, ctx.primitives.str)
|
||||||
|
&& ctx.unifier.unioned(right_ty, ctx.primitives.str)
|
||||||
|
{
|
||||||
|
// Only handle == and != for strings here
|
||||||
|
if ops.len() == 1 && (ops[0] == ast::Cmpop::Eq || ops[0] == ast::Cmpop::NotEq) {
|
||||||
|
// Extract string data
|
||||||
|
let lhs_struct = lhs_val.into_struct_value();
|
||||||
|
let lhs_ptr = match ctx.builder.build_extract_value(lhs_struct, 0, "lhs_ptr") {
|
||||||
|
Ok(val) => val.into_pointer_value(),
|
||||||
|
Err(e) => return Err(format!("Failed to extract lhs_ptr: {e:?}")),
|
||||||
|
};
|
||||||
|
let lhs_len = match ctx.builder.build_extract_value(lhs_struct, 1, "lhs_len") {
|
||||||
|
Ok(val) => val.into_int_value(),
|
||||||
|
Err(e) => return Err(format!("Failed to extract lhs_len: {e:?}")),
|
||||||
|
};
|
||||||
|
|
||||||
|
let rhs_struct = rhs_val.into_struct_value();
|
||||||
|
let rhs_ptr = match ctx.builder.build_extract_value(rhs_struct, 0, "rhs_ptr") {
|
||||||
|
Ok(val) => val.into_pointer_value(),
|
||||||
|
Err(e) => return Err(format!("Failed to extract rhs_ptr: {e:?}")),
|
||||||
|
};
|
||||||
|
let rhs_len = match ctx.builder.build_extract_value(rhs_struct, 1, "rhs_len") {
|
||||||
|
Ok(val) => val.into_int_value(),
|
||||||
|
Err(e) => return Err(format!("Failed to extract rhs_len: {e:?}")),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get or declare nac3_str_eq function
|
||||||
|
let str_eq_fn = if let Some(fun) = ctx.module.get_function("nac3_str_eq") {
|
||||||
|
fun
|
||||||
|
} else {
|
||||||
|
let bool_type = ctx.ctx.bool_type();
|
||||||
|
let i8_ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||||
|
let usize_type = generator.get_size_type(ctx.ctx);
|
||||||
|
let fn_type = bool_type.fn_type(
|
||||||
|
&[i8_ptr_type.into(), usize_type.into(), i8_ptr_type.into(), usize_type.into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
ctx.module.add_function("nac3_str_eq", fn_type, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Call nac3_str_eq(lhs_ptr, lhs_len, rhs_ptr, rhs_len)
|
||||||
|
let call_site = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(
|
||||||
|
str_eq_fn,
|
||||||
|
&[lhs_ptr.into(), lhs_len.into(), rhs_ptr.into(), rhs_len.into()],
|
||||||
|
"str_eq_call",
|
||||||
|
)
|
||||||
|
.expect("Failed to build call to nac3_str_eq");
|
||||||
|
|
||||||
|
// The function returns a bool (i1 in LLVM)
|
||||||
|
let eq_result = match call_site.try_as_basic_value() {
|
||||||
|
Either::Left(inkwell::values::BasicValueEnum::IntValue(val)) => val,
|
||||||
|
Either::Left(_) | Either::Right(_) => {
|
||||||
|
return Err("nac3_str_eq did not return an i1".to_string())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert i1 to i8 if NAC3 bool is i8
|
||||||
|
let eq_i8 = match ctx.builder.build_int_z_extend(eq_result, ctx.ctx.i8_type(), "eq_i8")
|
||||||
|
{
|
||||||
|
Ok(val) => val,
|
||||||
|
Err(e) => return Err(format!("Failed to extend i1 to i8: {e:?}")),
|
||||||
|
};
|
||||||
|
|
||||||
|
// If the operation is NotEq, invert the result
|
||||||
|
let final_result = if ops[0] == ast::Cmpop::NotEq {
|
||||||
|
match ctx.builder.build_not(eq_i8, "neq") {
|
||||||
|
Ok(val) => val,
|
||||||
|
Err(e) => return Err(format!("Failed to invert eq_i8 for NotEq: {e:?}")),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
eq_i8
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return as ValueEnum::Dynamic
|
||||||
|
return Ok(Some(ValueEnum::Dynamic(final_result.into())));
|
||||||
|
}
|
||||||
|
return Err(format!("Operator '{:?}' not supported for strings", ops[0]));
|
||||||
|
}
|
||||||
|
|
||||||
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
|
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
|
||||||
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
|
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
|
||||||
let (left_ty, lhs) = lhs;
|
let (left_ty, lhs) = lhs;
|
||||||
|
@ -10,7 +10,7 @@ use inkwell::{
|
|||||||
OptimizationLevel,
|
OptimizationLevel,
|
||||||
};
|
};
|
||||||
use nac3parser::{
|
use nac3parser::{
|
||||||
ast::{fold::Fold, FileName, StrRef, Cmpop},
|
ast::{fold::Fold, FileName, StrRef},
|
||||||
parser::parse_program,
|
parser::parse_program,
|
||||||
};
|
};
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
@ -474,24 +474,3 @@ fn test_classes_ndarray_type_new() {
|
|||||||
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
|
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
|
||||||
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
|
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_string_equality() {
|
|
||||||
use crate::symbol_resolver::SymbolValue;
|
|
||||||
|
|
||||||
let str1 = SymbolValue::Str("hello".to_string());
|
|
||||||
let str2 = SymbolValue::Str("hello".to_string());
|
|
||||||
let str3 = SymbolValue::Str("world".to_string());
|
|
||||||
|
|
||||||
// Test equality (==)
|
|
||||||
let result_eq = str1.evaluate_cmp_op(&Cmpop::Eq, &str2).unwrap();
|
|
||||||
assert_eq!(result_eq, SymbolValue::Bool(true));
|
|
||||||
|
|
||||||
// Test inequality (!=) with different strings
|
|
||||||
let result_neq_true = str1.evaluate_cmp_op(&Cmpop::NotEq, &str3).unwrap();
|
|
||||||
assert_eq!(result_neq_true, SymbolValue::Bool(true));
|
|
||||||
|
|
||||||
// Test inequality (!=) with identical strings
|
|
||||||
let result_neq_false = str1.evaluate_cmp_op(&Cmpop::NotEq, &str2).unwrap();
|
|
||||||
assert_eq!(result_neq_false, SymbolValue::Bool(false));
|
|
||||||
}
|
|
@ -9,7 +9,7 @@ use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, Struct
|
|||||||
use itertools::{chain, izip, Itertools};
|
use itertools::{chain, izip, Itertools};
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
|
|
||||||
use nac3parser::ast::{Constant, Expr, Location, StrRef, Cmpop};
|
use nac3parser::ast::{Constant, Expr, Location, StrRef};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{CodeGenContext, CodeGenerator},
|
codegen::{CodeGenContext, CodeGenerator},
|
||||||
@ -147,26 +147,26 @@ impl SymbolValue {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn evaluate_cmp_op(
|
// pub fn evaluate_cmp_op(
|
||||||
&self,
|
// &self,
|
||||||
op: &Cmpop,
|
// op: &Cmpop,
|
||||||
other: &SymbolValue,
|
// other: &SymbolValue,
|
||||||
) -> Result<SymbolValue, String> {
|
// ) -> Result<SymbolValue, String> {
|
||||||
match (self, other, op) {
|
// match (self, other, op) {
|
||||||
// Integer comparisons
|
// // // Integer comparisons
|
||||||
(SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)),
|
// // (SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)),
|
||||||
(SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)),
|
// // (SymbolValue::I32(a), SymbolValue::I32(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)),
|
||||||
// String comparisons
|
// // String comparisons
|
||||||
(SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)),
|
// (SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::Eq) => Ok(SymbolValue::Bool(a == b)),
|
||||||
(SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)),
|
// (SymbolValue::Str(a), SymbolValue::Str(b), Cmpop::NotEq) => Ok(SymbolValue::Bool(a != b)),
|
||||||
// Add other types and comparison operators as needed
|
// // Add other types and comparison operators as needed
|
||||||
_ => Err(format!(
|
// _ => Err(format!(
|
||||||
"Unsupported comparison operation for {:?} and {:?} with operator {:?}",
|
// "Unsupported comparison operation for {:?} and {:?} with operator {:?}",
|
||||||
self, other, op
|
// self, other, op
|
||||||
)),
|
// )),
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
/// Returns the [`Type`] representing the data type of this value.
|
/// Returns the [`Type`] representing the data type of this value.
|
||||||
pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type {
|
pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type {
|
||||||
match self {
|
match self {
|
||||||
|
@ -748,7 +748,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||||
|
|
||||||
/* str ========= */
|
/* str ========= */
|
||||||
impl_cmpop(unifier, store, str_t, &[str_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
|
impl_eq(unifier, store, str_t, &[str_t], Some(bool_t));
|
||||||
|
|
||||||
/* list ======== */
|
/* list ======== */
|
||||||
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
|
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
|
||||||
|
@ -105,6 +105,14 @@ uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t
|
|||||||
__builtin_unreachable();
|
__builtin_unreachable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compare two strings by content and length.
|
||||||
|
bool nac3_str_eq(const char* lhs, size_t lhs_len, const char* rhs, size_t rhs_len) {
|
||||||
|
if (lhs_len != rhs_len) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return memcmp(lhs, rhs, lhs_len) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
// See `struct Exception<'a>` in
|
// See `struct Exception<'a>` in
|
||||||
// https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs
|
// https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs
|
||||||
struct Exception {
|
struct Exception {
|
||||||
|
@ -4,24 +4,90 @@ def output_bool(x: bool):
|
|||||||
|
|
||||||
|
|
||||||
def str_eq():
|
def str_eq():
|
||||||
|
# Basic cases
|
||||||
output_bool("" == "")
|
output_bool("" == "")
|
||||||
output_bool("a" == "")
|
output_bool("a" == "")
|
||||||
output_bool("a" == "b")
|
output_bool("a" == "b")
|
||||||
output_bool("b" == "a")
|
output_bool("b" == "a")
|
||||||
output_bool("a" == "a")
|
output_bool("a" == "a")
|
||||||
|
|
||||||
|
# Longer identical strings
|
||||||
output_bool("test string" == "test string")
|
output_bool("test string" == "test string")
|
||||||
|
output_bool("Lorem ipsum dolor sit amet" == "Lorem ipsum dolor sit amet")
|
||||||
|
|
||||||
|
# Different by one character
|
||||||
output_bool("test string1" == "test string2")
|
output_bool("test string1" == "test string2")
|
||||||
|
|
||||||
|
# Numeric strings
|
||||||
|
output_bool("123" == "123")
|
||||||
|
output_bool("123" == "321")
|
||||||
|
|
||||||
|
# Different lengths
|
||||||
|
output_bool("abc" == "abcde")
|
||||||
|
|
||||||
|
# Case sensitivity
|
||||||
|
output_bool("Hello, World!" == "Hello, World!")
|
||||||
|
output_bool("CaseSensitive" == "casesensitive")
|
||||||
|
|
||||||
|
# Leading and trailing spaces
|
||||||
|
output_bool(" leading space" == "leading space")
|
||||||
|
output_bool("trailing space " == "trailing space")
|
||||||
|
output_bool(" " == " ")
|
||||||
|
|
||||||
|
# Special characters and punctuation
|
||||||
|
output_bool("special@#%$^&*()_+{}|:<>?`~chars" == "special@#%$^&*()_+{}|:<>?`~chars")
|
||||||
|
|
||||||
|
# Unicode strings
|
||||||
|
output_bool("café" == "café") # Same accented character
|
||||||
|
output_bool("café" == "cafe") # Accented vs unaccented
|
||||||
|
|
||||||
|
# Strings with newline and tab
|
||||||
|
output_bool("line1\nline2" == "line1\nline2")
|
||||||
|
output_bool("tab\tseparated" == "tab\tseparated")
|
||||||
|
output_bool("line1\nline2" == "line1 line2")
|
||||||
|
|
||||||
|
|
||||||
def str_ne():
|
def str_ne():
|
||||||
|
# Basic cases
|
||||||
output_bool("" != "")
|
output_bool("" != "")
|
||||||
output_bool("a" != "")
|
output_bool("a" != "")
|
||||||
output_bool("a" != "b")
|
output_bool("a" != "b")
|
||||||
output_bool("b" != "a")
|
output_bool("b" != "a")
|
||||||
output_bool("a" != "a")
|
output_bool("a" != "a")
|
||||||
|
|
||||||
|
# Longer identical strings
|
||||||
output_bool("test string" != "test string")
|
output_bool("test string" != "test string")
|
||||||
|
|
||||||
|
# Different by one character
|
||||||
output_bool("test string1" != "test string2")
|
output_bool("test string1" != "test string2")
|
||||||
|
|
||||||
|
# Numeric strings
|
||||||
|
output_bool("123" != "123")
|
||||||
|
output_bool("123" != "321")
|
||||||
|
|
||||||
|
# Different lengths
|
||||||
|
output_bool("abc" != "abcde")
|
||||||
|
|
||||||
|
# Case sensitivity
|
||||||
|
output_bool("Hello, World!" != "Hello, World!")
|
||||||
|
output_bool("CaseSensitive" != "casesensitive")
|
||||||
|
|
||||||
|
# Leading and trailing spaces
|
||||||
|
output_bool(" leading space" != "leading space")
|
||||||
|
output_bool("trailing space " != "trailing space")
|
||||||
|
output_bool(" " != " ")
|
||||||
|
|
||||||
|
# Special characters and punctuation
|
||||||
|
output_bool("special@#%$^&*()_+{}|:<>?`~chars" != "special@#%$^&*()_+{}|:<>?`~chars")
|
||||||
|
|
||||||
|
# Unicode strings
|
||||||
|
output_bool("café" != "café")
|
||||||
|
output_bool("café" != "cafe")
|
||||||
|
|
||||||
|
# Strings with newline and tab
|
||||||
|
output_bool("line1\nline2" != "line1\nline2")
|
||||||
|
output_bool("tab\tseparated" != "tab\tseparated")
|
||||||
|
output_bool("line1\nline2" != "line1 line2")
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
str_eq()
|
str_eq()
|
||||||
|
Loading…
Reference in New Issue
Block a user