diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index cd0544b..632619e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1897,38 +1897,33 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ) -> Result>, String> { debug_assert_eq!(comparators.len(), ops.len()); - // Handle NDArray comparisons first if comparators.len() == 1 { - let left_ty = match left.0 { - Some(ref ty) => ctx.unifier.get_representative(*ty), - None => return Err("Left type is None".to_string()), - }; - let right_ty = match comparators[0].0 { - Some(ref ty) => ctx.unifier.get_representative(*ty), - None => return Err("Right type is None".to_string()), - }; + let (Some(left_ty), _) = left else { codegen_unreachable!(ctx) }; + let left_ty = ctx.unifier.get_representative(left_ty); + + let (Some(right_ty), _) = comparators[0] else { codegen_unreachable!(ctx) }; + let right_ty = ctx.unifier.get_representative(right_ty); if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); - // Safely unwrap the left and right operands let (left_ty_opt, lhs) = left; let left_ty = match left_ty_opt { Some(ty) => ctx.unifier.get_representative(ty), - None => return Err("Left type is None".to_string()), + None => codegen_unreachable!(ctx), }; let (right_ty_opt, rhs) = match comparators.first().copied() { Some((Some(ty), val)) => (Some(ty), val), Some((None, _)) | None => { - return Err("Comparator type is None".to_string()); + codegen_unreachable!(ctx); } }; let right_ty = match right_ty_opt { Some(ty) => ctx.unifier.get_representative(ty), - None => return Err("Right type is None".to_string()), + None => codegen_unreachable!(ctx), }; let op = ops[0]; @@ -2014,52 +2009,36 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } } - // Safely unwrap the left and first comparator operands - let (left_ty_opt, lhs_val) = left; - let left_ty = match left_ty_opt { - Some(ty) => ctx.unifier.get_representative(ty), - None => return Err("Left type is None".to_string()), - }; + let (Some(left_ty), lhs_val) = left else { codegen_unreachable!(ctx) }; + let left_ty = ctx.unifier.get_representative(left_ty); - let (right_ty_opt, rhs_val) = match comparators.first().copied() { - Some((Some(ty), val)) => (Some(ty), val), - Some((None, _)) | None => { - return Err("Comparator type is None".to_string()); - } - }; - let right_ty = match right_ty_opt { - Some(ty) => ctx.unifier.get_representative(ty), - None => return Err("Right type is None".to_string()), + let (Some(right_ty), rhs_val) = comparators.first().copied().unwrap() else { + codegen_unreachable!(ctx) }; + let right_ty = ctx.unifier.get_representative(right_ty); - // Handle string comparisons if ctx.unifier.unioned(left_ty, ctx.primitives.str) && ctx.unifier.unioned(right_ty, ctx.primitives.str) { - // Only handle == and != for strings here if ops.len() == 1 && (ops[0] == ast::Cmpop::Eq || ops[0] == ast::Cmpop::NotEq) { - // Extract string data let lhs_struct = lhs_val.into_struct_value(); - let lhs_ptr = match ctx.builder.build_extract_value(lhs_struct, 0, "lhs_ptr") { - Ok(val) => val.into_pointer_value(), - Err(e) => return Err(format!("Failed to extract lhs_ptr: {e:?}")), - }; - let lhs_len = match ctx.builder.build_extract_value(lhs_struct, 1, "lhs_len") { - Ok(val) => val.into_int_value(), - Err(e) => return Err(format!("Failed to extract lhs_len: {e:?}")), - }; + let lhs_ptr = ctx + .builder + .build_extract_value(lhs_struct, 0, "lhs_ptr") + .unwrap() + .into_pointer_value(); + let lhs_len = + ctx.builder.build_extract_value(lhs_struct, 1, "lhs_len").unwrap().into_int_value(); let rhs_struct = rhs_val.into_struct_value(); - let rhs_ptr = match ctx.builder.build_extract_value(rhs_struct, 0, "rhs_ptr") { - Ok(val) => val.into_pointer_value(), - Err(e) => return Err(format!("Failed to extract rhs_ptr: {e:?}")), - }; - let rhs_len = match ctx.builder.build_extract_value(rhs_struct, 1, "rhs_len") { - Ok(val) => val.into_int_value(), - Err(e) => return Err(format!("Failed to extract rhs_len: {e:?}")), - }; + let rhs_ptr = ctx + .builder + .build_extract_value(rhs_struct, 0, "rhs_ptr") + .unwrap() + .into_pointer_value(); + let rhs_len = + ctx.builder.build_extract_value(rhs_struct, 1, "rhs_len").unwrap().into_int_value(); - // Get or declare nac3_str_eq function let str_eq_fn = if let Some(fun) = ctx.module.get_function("nac3_str_eq") { fun } else { @@ -2073,7 +2052,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx.module.add_function("nac3_str_eq", fn_type, None) }; - // Call nac3_str_eq(lhs_ptr, lhs_len, rhs_ptr, rhs_len) let call_site = ctx .builder .build_call( @@ -2083,35 +2061,23 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ) .expect("Failed to build call to nac3_str_eq"); - // The function returns a bool (i1 in LLVM) let eq_result = match call_site.try_as_basic_value() { Either::Left(inkwell::values::BasicValueEnum::IntValue(val)) => val, - Either::Left(_) | Either::Right(_) => { - return Err("nac3_str_eq did not return an i1".to_string()) - } + Either::Left(_) | Either::Right(_) => codegen_unreachable!(ctx), }; - // Convert i1 to i8 if NAC3 bool is i8 - let eq_i8 = match ctx.builder.build_int_z_extend(eq_result, ctx.ctx.i8_type(), "eq_i8") - { - Ok(val) => val, - Err(e) => return Err(format!("Failed to extend i1 to i8: {e:?}")), - }; + let eq_i8 = + ctx.builder.build_int_z_extend(eq_result, ctx.ctx.i8_type(), "eq_i8").unwrap(); - // If the operation is NotEq, invert the result let final_result = if ops[0] == ast::Cmpop::NotEq { - match ctx.builder.build_not(eq_i8, "neq") { - Ok(val) => val, - Err(e) => return Err(format!("Failed to invert eq_i8 for NotEq: {e:?}")), - } + ctx.builder.build_not(eq_i8, "neq").unwrap() } else { eq_i8 }; - // Return as ValueEnum::Dynamic return Ok(Some(ValueEnum::Dynamic(final_result.into()))); } - return Err(format!("Operator '{:?}' not supported for strings", ops[0])); + codegen_unreachable!(ctx); } let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),) @@ -2348,7 +2314,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } if ![Cmpop::Eq, Cmpop::NotEq].contains(op) { - todo!("Only __eq__ and __ne__ is implemented for lists") + codegen_unreachable!(ctx, "Only __eq__ and __ne__ supported for this type") } let left_val = @@ -2472,10 +2438,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( gen_list_cmpop(generator, ctx)? } else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) { let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else { - return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) + codegen_unreachable!(ctx) }; let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else { - return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) + codegen_unreachable!(ctx) }; if ![Cmpop::Eq, Cmpop::NotEq].contains(op) { @@ -2600,10 +2566,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().get_poison() } else { - return Err(format!("'{}' not supported between instances of '{}' and '{}'", - op.op_info().symbol, - ctx.unifier.stringify(left_ty), - ctx.unifier.stringify(right_ty))) + codegen_unreachable!(ctx) }; Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) diff --git a/nac3standalone/demo/src/str.py b/nac3standalone/demo/src/str.py index 756f1d8..a9d1dcb 100644 --- a/nac3standalone/demo/src/str.py +++ b/nac3standalone/demo/src/str.py @@ -25,28 +25,11 @@ def str_eq(): # Different lengths output_bool("abc" == "abcde") - # Case sensitivity - output_bool("Hello, World!" == "Hello, World!") - output_bool("CaseSensitive" == "casesensitive") - # Leading and trailing spaces output_bool(" leading space" == "leading space") output_bool("trailing space " == "trailing space") output_bool(" " == " ") - # Special characters and punctuation - output_bool("special@#%$^&*()_+{}|:<>?`~chars" == "special@#%$^&*()_+{}|:<>?`~chars") - - # Unicode strings - output_bool("café" == "café") # Same accented character - output_bool("café" == "cafe") # Accented vs unaccented - - # Strings with newline and tab - output_bool("line1\nline2" == "line1\nline2") - output_bool("tab\tseparated" == "tab\tseparated") - output_bool("line1\nline2" == "line1 line2") - - def str_ne(): # Basic cases output_bool("" != "") @@ -68,27 +51,11 @@ def str_ne(): # Different lengths output_bool("abc" != "abcde") - # Case sensitivity - output_bool("Hello, World!" != "Hello, World!") - output_bool("CaseSensitive" != "casesensitive") - # Leading and trailing spaces output_bool(" leading space" != "leading space") output_bool("trailing space " != "trailing space") output_bool(" " != " ") - # Special characters and punctuation - output_bool("special@#%$^&*()_+{}|:<>?`~chars" != "special@#%$^&*()_+{}|:<>?`~chars") - - # Unicode strings - output_bool("café" != "café") - output_bool("café" != "cafe") - - # Strings with newline and tab - output_bool("line1\nline2" != "line1\nline2") - output_bool("tab\tseparated" != "tab\tseparated") - output_bool("line1\nline2" != "line1 line2") - def run() -> int32: str_eq() str_ne()