From c37c7e8975ea098b1b2ddcc1f2cab6138b958533 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 7 Feb 2025 10:39:21 +0800 Subject: [PATCH] [core] codegen/expr: Simplify `gen_*_expr_with_values` return value These functions always return `BasicValueEnum` because they operate on `BasicValueEnum`s, and they also always return a value. --- nac3core/src/codegen/expr.rs | 88 ++++++------------- nac3core/src/codegen/values/ndarray/matmul.rs | 8 +- 2 files changed, 30 insertions(+), 66 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 53aa5f14..986ed992 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1319,7 +1319,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op: Binop, right: (&Option, BasicValueEnum<'ctx>), loc: Location, -) -> Result>, String> { +) -> Result, String> { let (left_ty, left_val) = left; let (right_ty, right_val) = right; @@ -1330,14 +1330,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( // which would be unchanged until further unification, which we would never do // when doing code generation for function instances if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, true).into())) + Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, true)) } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { - Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, false).into())) + Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, false)) } else if [Operator::LShift, Operator::RShift].contains(&op.base) { let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1); - Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed).into())) + Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed)) } else if ty1 == ty2 && ctx.primitives.float == ty1 { - Ok(Some(ctx.gen_float_ops(op.base, left_val, right_val).into())) + Ok(ctx.gen_float_ops(op.base, left_val, right_val)) } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { // Pow is the only operator that would pass typecheck between float and int assert_eq!(op.base, Operator::Pow); @@ -1347,7 +1347,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( right_val.into_int_value(), Some("f_pow_i"), ); - Ok(Some(res.into())) + Ok(res.into()) } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { @@ -1437,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().const_zero(), ); - Ok(Some(new_list.as_abi_value(ctx).into())) + Ok(new_list.as_abi_value(ctx).into()) } Operator::Mult => { @@ -1524,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( llvm_usize.const_int(1, false), )?; - Ok(Some(new_list.as_abi_value(ctx).into())) + Ok(new_list.as_abi_value(ctx).into()) } _ => todo!("Operator not supported"), @@ -1563,7 +1563,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let result = left .matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out)) .split_unsized(generator, ctx); - Ok(Some(result.to_basic_value_enum().into())) + Ok(result.to_basic_value_enum()) } else { // For other operations, they are all elementwise operations. @@ -1594,14 +1594,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( op, (&Some(ty2_dtype), right_value), ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, common_dtype)?; + )?; Ok(result) }) .unwrap(); - Ok(Some(result.as_abi_value(ctx).into())) + Ok(result.as_abi_value(ctx).into()) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); @@ -1650,7 +1648,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( (&signature, fun_id), vec![(None, right_val.into())], ) - .map(|f| f.map(Into::into)) + .map(Option::unwrap) + .map(BasicValueEnum::into) } } @@ -1688,6 +1687,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( (&right.custom, right_val), loc, ) + .map(|res| Some(res.into())) } /// Generates LLVM IR for a unary operator expression using the [`Type`] and @@ -1697,11 +1697,11 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, '_>, op: ast::Unaryop, operand: (&Option, BasicValueEnum<'ctx>), -) -> Result>, String> { +) -> Result, String> { let (ty, val) = operand; let ty = ctx.unifier.get_representative(ty.unwrap()); - Ok(Some(if ty == ctx.primitives.bool { + Ok(if ty == ctx.primitives.bool { let val = val.into_int_value(); if op == ast::Unaryop::Not { let not = ctx @@ -1722,7 +1722,6 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap(), ), )? - .unwrap() } } else if [ ctx.primitives.int32, @@ -1791,16 +1790,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( ctx, NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() }, |generator, ctx, scalar| { - gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))? - .map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype)) - .unwrap() + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar)) }, )?; mapped_ndarray.as_abi_value(ctx).into() } else { unimplemented!() - })) + }) } /// Generates LLVM IR for a unary operator expression. @@ -1820,6 +1817,7 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>( }; gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val)) + .map(|res| Some(res.into())) } /// Generates LLVM IR for a comparison operator expression using the [`Type`] and @@ -1830,7 +1828,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( left: (Option, BasicValueEnum<'ctx>), ops: &[ast::Cmpop], comparators: &[(Option, BasicValueEnum<'ctx>)], -) -> Result>, String> { +) -> Result, String> { debug_assert_eq!(comparators.len(), ops.len()); if comparators.len() == 1 { @@ -1872,19 +1870,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( (Some(left_ty_dtype), left_scalar), &[op], &[(Some(right_ty_dtype), right_scalar)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, )?; Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) }, )?; - return Ok(Some(result_ndarray.as_abi_value(ctx).into())); + return Ok(result_ndarray.as_abi_value(ctx).into()); } } @@ -2007,13 +1999,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx, Unaryop::Not, (&Some(ctx.primitives.bool), result.into()), - ) - .transpose() - .unwrap() - .and_then(|res| { - res.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - })? - .into_int_value() + )?.into_int_value() } else { result } @@ -2116,9 +2102,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( &[Cmpop::Eq], &[(Some(right_elem_ty), right)], )? - .unwrap() - .to_basic_value_enum(ctx, generator, ctx.primitives.bool) - .unwrap() .into_int_value(); gen_if_callback( @@ -2167,8 +2150,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( Unaryop::Not, (&Some(ctx.primitives.bool), acc.into()), )? - .unwrap() - .to_basic_value_enum(ctx, generator, ctx.primitives.bool)? .into_int_value() } else { acc @@ -2256,12 +2237,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( &[op], &[(Some(right_ty), right_elem)], ) - .transpose() - .unwrap() - .and_then(|v| { - v.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - })? - .into_int_value(); + .map(BasicValueEnum::into_int_value)?; Ok(ctx.builder.build_not( generator.bool_to_i1(ctx, cmp), @@ -2301,14 +2277,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( generator, ctx, Unaryop::Not, - (&Some(ctx.primitives.bool), cmp_phi.into()) - ) - .transpose() - .unwrap() - .and_then(|res| { - res.to_basic_value_enum(ctx, generator, ctx.primitives.bool) - })? - .into_int_value() + (&Some(ctx.primitives.bool), cmp_phi.into()), + )?.into_int_value() } else { cmp_phi } @@ -2333,12 +2303,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( }; Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) - })?; + })?.unwrap(); - Ok(Some(match cmp_val { - Some(v) => v.into(), - None => return Ok(None), - })) + Ok(cmp_val.into()) } /// Generates LLVM IR for a comparison operator expression. @@ -2385,6 +2352,7 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( ops, comparator_vals.as_slice(), ) + .map(|res| Some(res.into())) } /// See [`CodeGenerator::gen_expr`]. diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index f12d36c1..cc8d059a 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -213,9 +213,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( Binop::normal(Operator::Mult), (&Some(rhs_dtype), b_kj), ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, dst_dtype)?; + )?; // dst_[...]ij += x let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); @@ -226,9 +224,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( Binop::normal(Operator::Add), (&Some(dst_dtype), x), ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, dst_dtype)?; + )?; ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); Ok(())