1
0
forked from M-Labs/nac3

Update error messages and remove redundant test cases

This commit is contained in:
ram 2024-12-06 08:41:07 +00:00
parent 914e3ba096
commit c6e6b7bc95
2 changed files with 36 additions and 106 deletions

View File

@ -1897,38 +1897,33 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
debug_assert_eq!(comparators.len(), ops.len()); debug_assert_eq!(comparators.len(), ops.len());
// Handle NDArray comparisons first
if comparators.len() == 1 { if comparators.len() == 1 {
let left_ty = match left.0 { let (Some(left_ty), _) = left else { codegen_unreachable!(ctx) };
Some(ref ty) => ctx.unifier.get_representative(*ty), let left_ty = ctx.unifier.get_representative(left_ty);
None => return Err("Left type is None".to_string()),
}; let (Some(right_ty), _) = comparators[0] else { codegen_unreachable!(ctx) };
let right_ty = match comparators[0].0 { let right_ty = ctx.unifier.get_representative(right_ty);
Some(ref ty) => ctx.unifier.get_representative(*ty),
None => return Err("Right type is None".to_string()),
};
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
// Safely unwrap the left and right operands
let (left_ty_opt, lhs) = left; let (left_ty_opt, lhs) = left;
let left_ty = match left_ty_opt { let left_ty = match left_ty_opt {
Some(ty) => ctx.unifier.get_representative(ty), Some(ty) => ctx.unifier.get_representative(ty),
None => return Err("Left type is None".to_string()), None => codegen_unreachable!(ctx),
}; };
let (right_ty_opt, rhs) = match comparators.first().copied() { let (right_ty_opt, rhs) = match comparators.first().copied() {
Some((Some(ty), val)) => (Some(ty), val), Some((Some(ty), val)) => (Some(ty), val),
Some((None, _)) | None => { Some((None, _)) | None => {
return Err("Comparator type is None".to_string()); codegen_unreachable!(ctx);
} }
}; };
let right_ty = match right_ty_opt { let right_ty = match right_ty_opt {
Some(ty) => ctx.unifier.get_representative(ty), Some(ty) => ctx.unifier.get_representative(ty),
None => return Err("Right type is None".to_string()), None => codegen_unreachable!(ctx),
}; };
let op = ops[0]; let op = ops[0];
@ -2014,52 +2009,36 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
} }
} }
// Safely unwrap the left and first comparator operands let (Some(left_ty), lhs_val) = left else { codegen_unreachable!(ctx) };
let (left_ty_opt, lhs_val) = left; let left_ty = ctx.unifier.get_representative(left_ty);
let left_ty = match left_ty_opt {
Some(ty) => ctx.unifier.get_representative(ty),
None => return Err("Left type is None".to_string()),
};
let (right_ty_opt, rhs_val) = match comparators.first().copied() { let (Some(right_ty), rhs_val) = comparators.first().copied().unwrap() else {
Some((Some(ty), val)) => (Some(ty), val), codegen_unreachable!(ctx)
Some((None, _)) | None => {
return Err("Comparator type is None".to_string());
}
};
let right_ty = match right_ty_opt {
Some(ty) => ctx.unifier.get_representative(ty),
None => return Err("Right type is None".to_string()),
}; };
let right_ty = ctx.unifier.get_representative(right_ty);
// Handle string comparisons
if ctx.unifier.unioned(left_ty, ctx.primitives.str) if ctx.unifier.unioned(left_ty, ctx.primitives.str)
&& ctx.unifier.unioned(right_ty, ctx.primitives.str) && ctx.unifier.unioned(right_ty, ctx.primitives.str)
{ {
// Only handle == and != for strings here
if ops.len() == 1 && (ops[0] == ast::Cmpop::Eq || ops[0] == ast::Cmpop::NotEq) { if ops.len() == 1 && (ops[0] == ast::Cmpop::Eq || ops[0] == ast::Cmpop::NotEq) {
// Extract string data
let lhs_struct = lhs_val.into_struct_value(); let lhs_struct = lhs_val.into_struct_value();
let lhs_ptr = match ctx.builder.build_extract_value(lhs_struct, 0, "lhs_ptr") { let lhs_ptr = ctx
Ok(val) => val.into_pointer_value(), .builder
Err(e) => return Err(format!("Failed to extract lhs_ptr: {e:?}")), .build_extract_value(lhs_struct, 0, "lhs_ptr")
}; .unwrap()
let lhs_len = match ctx.builder.build_extract_value(lhs_struct, 1, "lhs_len") { .into_pointer_value();
Ok(val) => val.into_int_value(), let lhs_len =
Err(e) => return Err(format!("Failed to extract lhs_len: {e:?}")), ctx.builder.build_extract_value(lhs_struct, 1, "lhs_len").unwrap().into_int_value();
};
let rhs_struct = rhs_val.into_struct_value(); let rhs_struct = rhs_val.into_struct_value();
let rhs_ptr = match ctx.builder.build_extract_value(rhs_struct, 0, "rhs_ptr") { let rhs_ptr = ctx
Ok(val) => val.into_pointer_value(), .builder
Err(e) => return Err(format!("Failed to extract rhs_ptr: {e:?}")), .build_extract_value(rhs_struct, 0, "rhs_ptr")
}; .unwrap()
let rhs_len = match ctx.builder.build_extract_value(rhs_struct, 1, "rhs_len") { .into_pointer_value();
Ok(val) => val.into_int_value(), let rhs_len =
Err(e) => return Err(format!("Failed to extract rhs_len: {e:?}")), ctx.builder.build_extract_value(rhs_struct, 1, "rhs_len").unwrap().into_int_value();
};
// Get or declare nac3_str_eq function
let str_eq_fn = if let Some(fun) = ctx.module.get_function("nac3_str_eq") { let str_eq_fn = if let Some(fun) = ctx.module.get_function("nac3_str_eq") {
fun fun
} else { } else {
@ -2073,7 +2052,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
ctx.module.add_function("nac3_str_eq", fn_type, None) ctx.module.add_function("nac3_str_eq", fn_type, None)
}; };
// Call nac3_str_eq(lhs_ptr, lhs_len, rhs_ptr, rhs_len)
let call_site = ctx let call_site = ctx
.builder .builder
.build_call( .build_call(
@ -2083,35 +2061,23 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
) )
.expect("Failed to build call to nac3_str_eq"); .expect("Failed to build call to nac3_str_eq");
// The function returns a bool (i1 in LLVM)
let eq_result = match call_site.try_as_basic_value() { let eq_result = match call_site.try_as_basic_value() {
Either::Left(inkwell::values::BasicValueEnum::IntValue(val)) => val, Either::Left(inkwell::values::BasicValueEnum::IntValue(val)) => val,
Either::Left(_) | Either::Right(_) => { Either::Left(_) | Either::Right(_) => codegen_unreachable!(ctx),
return Err("nac3_str_eq did not return an i1".to_string())
}
}; };
// Convert i1 to i8 if NAC3 bool is i8 let eq_i8 =
let eq_i8 = match ctx.builder.build_int_z_extend(eq_result, ctx.ctx.i8_type(), "eq_i8") ctx.builder.build_int_z_extend(eq_result, ctx.ctx.i8_type(), "eq_i8").unwrap();
{
Ok(val) => val,
Err(e) => return Err(format!("Failed to extend i1 to i8: {e:?}")),
};
// If the operation is NotEq, invert the result
let final_result = if ops[0] == ast::Cmpop::NotEq { let final_result = if ops[0] == ast::Cmpop::NotEq {
match ctx.builder.build_not(eq_i8, "neq") { ctx.builder.build_not(eq_i8, "neq").unwrap()
Ok(val) => val,
Err(e) => return Err(format!("Failed to invert eq_i8 for NotEq: {e:?}")),
}
} else { } else {
eq_i8 eq_i8
}; };
// Return as ValueEnum::Dynamic
return Ok(Some(ValueEnum::Dynamic(final_result.into()))); return Ok(Some(ValueEnum::Dynamic(final_result.into())));
} }
return Err(format!("Operator '{:?}' not supported for strings", ops[0])); codegen_unreachable!(ctx);
} }
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),) let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
@ -2348,7 +2314,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
} }
if ![Cmpop::Eq, Cmpop::NotEq].contains(op) { if ![Cmpop::Eq, Cmpop::NotEq].contains(op) {
todo!("Only __eq__ and __ne__ is implemented for lists") codegen_unreachable!(ctx, "Only __eq__ and __ne__ supported for this type")
} }
let left_val = let left_val =
@ -2472,10 +2438,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
gen_list_cmpop(generator, ctx)? gen_list_cmpop(generator, ctx)?
} else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) { } else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) {
let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else { let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else {
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) codegen_unreachable!(ctx)
}; };
let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else { let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else {
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) codegen_unreachable!(ctx)
}; };
if ![Cmpop::Eq, Cmpop::NotEq].contains(op) { if ![Cmpop::Eq, Cmpop::NotEq].contains(op) {
@ -2600,10 +2566,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
ctx.ctx.bool_type().get_poison() ctx.ctx.bool_type().get_poison()
} else { } else {
return Err(format!("'{}' not supported between instances of '{}' and '{}'", codegen_unreachable!(ctx)
op.op_info().symbol,
ctx.unifier.stringify(left_ty),
ctx.unifier.stringify(right_ty)))
}; };
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))

View File

@ -25,28 +25,11 @@ def str_eq():
# Different lengths # Different lengths
output_bool("abc" == "abcde") output_bool("abc" == "abcde")
# Case sensitivity
output_bool("Hello, World!" == "Hello, World!")
output_bool("CaseSensitive" == "casesensitive")
# Leading and trailing spaces # Leading and trailing spaces
output_bool(" leading space" == "leading space") output_bool(" leading space" == "leading space")
output_bool("trailing space " == "trailing space") output_bool("trailing space " == "trailing space")
output_bool(" " == " ") output_bool(" " == " ")
# Special characters and punctuation
output_bool("special@#%$^&*()_+{}|:<>?`~chars" == "special@#%$^&*()_+{}|:<>?`~chars")
# Unicode strings
output_bool("café" == "café") # Same accented character
output_bool("café" == "cafe") # Accented vs unaccented
# Strings with newline and tab
output_bool("line1\nline2" == "line1\nline2")
output_bool("tab\tseparated" == "tab\tseparated")
output_bool("line1\nline2" == "line1 line2")
def str_ne(): def str_ne():
# Basic cases # Basic cases
output_bool("" != "") output_bool("" != "")
@ -68,27 +51,11 @@ def str_ne():
# Different lengths # Different lengths
output_bool("abc" != "abcde") output_bool("abc" != "abcde")
# Case sensitivity
output_bool("Hello, World!" != "Hello, World!")
output_bool("CaseSensitive" != "casesensitive")
# Leading and trailing spaces # Leading and trailing spaces
output_bool(" leading space" != "leading space") output_bool(" leading space" != "leading space")
output_bool("trailing space " != "trailing space") output_bool("trailing space " != "trailing space")
output_bool(" " != " ") output_bool(" " != " ")
# Special characters and punctuation
output_bool("special@#%$^&*()_+{}|:<>?`~chars" != "special@#%$^&*()_+{}|:<>?`~chars")
# Unicode strings
output_bool("café" != "café")
output_bool("café" != "cafe")
# Strings with newline and tab
output_bool("line1\nline2" != "line1\nline2")
output_bool("tab\tseparated" != "tab\tseparated")
output_bool("line1\nline2" != "line1 line2")
def run() -> int32: def run() -> int32:
str_eq() str_eq()
str_ne() str_ne()