[core] codegen/expr: Simplify gen_*_expr_with_values return value

These functions always return `BasicValueEnum` because they operate on
`BasicValueEnum`s, and they also always return a value.
This commit is contained in:
David Mak 2025-02-07 10:39:21 +08:00
parent 0d8cb909dd
commit c37c7e8975
2 changed files with 30 additions and 66 deletions

View File

@ -1319,7 +1319,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
op: Binop, op: Binop,
right: (&Option<Type>, BasicValueEnum<'ctx>), right: (&Option<Type>, BasicValueEnum<'ctx>),
loc: Location, loc: Location,
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
let (left_ty, left_val) = left; let (left_ty, left_val) = left;
let (right_ty, right_val) = right; let (right_ty, right_val) = right;
@ -1330,14 +1330,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
// which would be unchanged until further unification, which we would never do // which would be unchanged until further unification, which we would never do
// when doing code generation for function instances // when doing code generation for function instances
if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) {
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, true).into())) Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, true))
} else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) {
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, false).into())) Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, false))
} else if [Operator::LShift, Operator::RShift].contains(&op.base) { } else if [Operator::LShift, Operator::RShift].contains(&op.base) {
let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1); let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1);
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed).into())) Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed))
} else if ty1 == ty2 && ctx.primitives.float == ty1 { } else if ty1 == ty2 && ctx.primitives.float == ty1 {
Ok(Some(ctx.gen_float_ops(op.base, left_val, right_val).into())) Ok(ctx.gen_float_ops(op.base, left_val, right_val))
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
// Pow is the only operator that would pass typecheck between float and int // Pow is the only operator that would pass typecheck between float and int
assert_eq!(op.base, Operator::Pow); assert_eq!(op.base, Operator::Pow);
@ -1347,7 +1347,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
right_val.into_int_value(), right_val.into_int_value(),
Some("f_pow_i"), Some("f_pow_i"),
); );
Ok(Some(res.into())) Ok(res.into())
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
{ {
@ -1437,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
ctx.ctx.bool_type().const_zero(), ctx.ctx.bool_type().const_zero(),
); );
Ok(Some(new_list.as_abi_value(ctx).into())) Ok(new_list.as_abi_value(ctx).into())
} }
Operator::Mult => { Operator::Mult => {
@ -1524,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
)?; )?;
Ok(Some(new_list.as_abi_value(ctx).into())) Ok(new_list.as_abi_value(ctx).into())
} }
_ => todo!("Operator not supported"), _ => todo!("Operator not supported"),
@ -1563,7 +1563,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let result = left let result = left
.matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out)) .matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out))
.split_unsized(generator, ctx); .split_unsized(generator, ctx);
Ok(Some(result.to_basic_value_enum().into())) Ok(result.to_basic_value_enum())
} else { } else {
// For other operations, they are all elementwise operations. // For other operations, they are all elementwise operations.
@ -1594,14 +1594,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
op, op,
(&Some(ty2_dtype), right_value), (&Some(ty2_dtype), right_value),
ctx.current_loc, ctx.current_loc,
)? )?;
.unwrap()
.to_basic_value_enum(ctx, generator, common_dtype)?;
Ok(result) Ok(result)
}) })
.unwrap(); .unwrap();
Ok(Some(result.as_abi_value(ctx).into())) Ok(result.as_abi_value(ctx).into())
} }
} else { } else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
@ -1650,7 +1648,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
(&signature, fun_id), (&signature, fun_id),
vec![(None, right_val.into())], vec![(None, right_val.into())],
) )
.map(|f| f.map(Into::into)) .map(Option::unwrap)
.map(BasicValueEnum::into)
} }
} }
@ -1688,6 +1687,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
(&right.custom, right_val), (&right.custom, right_val),
loc, loc,
) )
.map(|res| Some(res.into()))
} }
/// Generates LLVM IR for a unary operator expression using the [`Type`] and /// Generates LLVM IR for a unary operator expression using the [`Type`] and
@ -1697,11 +1697,11 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
op: ast::Unaryop, op: ast::Unaryop,
operand: (&Option<Type>, BasicValueEnum<'ctx>), operand: (&Option<Type>, BasicValueEnum<'ctx>),
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
let (ty, val) = operand; let (ty, val) = operand;
let ty = ctx.unifier.get_representative(ty.unwrap()); let ty = ctx.unifier.get_representative(ty.unwrap());
Ok(Some(if ty == ctx.primitives.bool { Ok(if ty == ctx.primitives.bool {
let val = val.into_int_value(); let val = val.into_int_value();
if op == ast::Unaryop::Not { if op == ast::Unaryop::Not {
let not = ctx let not = ctx
@ -1722,7 +1722,6 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap(), ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap(),
), ),
)? )?
.unwrap()
} }
} else if [ } else if [
ctx.primitives.int32, ctx.primitives.int32,
@ -1791,16 +1790,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
ctx, ctx,
NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() }, NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() },
|generator, ctx, scalar| { |generator, ctx, scalar| {
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))? gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))
.map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype))
.unwrap()
}, },
)?; )?;
mapped_ndarray.as_abi_value(ctx).into() mapped_ndarray.as_abi_value(ctx).into()
} else { } else {
unimplemented!() unimplemented!()
})) })
} }
/// Generates LLVM IR for a unary operator expression. /// Generates LLVM IR for a unary operator expression.
@ -1820,6 +1817,7 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
}; };
gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val)) gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val))
.map(|res| Some(res.into()))
} }
/// Generates LLVM IR for a comparison operator expression using the [`Type`] and /// Generates LLVM IR for a comparison operator expression using the [`Type`] and
@ -1830,7 +1828,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
left: (Option<Type>, BasicValueEnum<'ctx>), left: (Option<Type>, BasicValueEnum<'ctx>),
ops: &[ast::Cmpop], ops: &[ast::Cmpop],
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)], comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
debug_assert_eq!(comparators.len(), ops.len()); debug_assert_eq!(comparators.len(), ops.len());
if comparators.len() == 1 { if comparators.len() == 1 {
@ -1872,19 +1870,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
(Some(left_ty_dtype), left_scalar), (Some(left_ty_dtype), left_scalar),
&[op], &[op],
&[(Some(right_ty_dtype), right_scalar)], &[(Some(right_ty_dtype), right_scalar)],
)?
.unwrap()
.to_basic_value_enum(
ctx,
generator,
ctx.primitives.bool,
)?; )?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
}, },
)?; )?;
return Ok(Some(result_ndarray.as_abi_value(ctx).into())); return Ok(result_ndarray.as_abi_value(ctx).into());
} }
} }
@ -2007,13 +1999,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
ctx, ctx,
Unaryop::Not, Unaryop::Not,
(&Some(ctx.primitives.bool), result.into()), (&Some(ctx.primitives.bool), result.into()),
) )?.into_int_value()
.transpose()
.unwrap()
.and_then(|res| {
res.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
})?
.into_int_value()
} else { } else {
result result
} }
@ -2116,9 +2102,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
&[Cmpop::Eq], &[Cmpop::Eq],
&[(Some(right_elem_ty), right)], &[(Some(right_elem_ty), right)],
)? )?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
.unwrap()
.into_int_value(); .into_int_value();
gen_if_callback( gen_if_callback(
@ -2167,8 +2150,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
Unaryop::Not, Unaryop::Not,
(&Some(ctx.primitives.bool), acc.into()), (&Some(ctx.primitives.bool), acc.into()),
)? )?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)?
.into_int_value() .into_int_value()
} else { } else {
acc acc
@ -2256,12 +2237,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
&[op], &[op],
&[(Some(right_ty), right_elem)], &[(Some(right_ty), right_elem)],
) )
.transpose() .map(BasicValueEnum::into_int_value)?;
.unwrap()
.and_then(|v| {
v.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
})?
.into_int_value();
Ok(ctx.builder.build_not( Ok(ctx.builder.build_not(
generator.bool_to_i1(ctx, cmp), generator.bool_to_i1(ctx, cmp),
@ -2301,14 +2277,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
generator, generator,
ctx, ctx,
Unaryop::Not, Unaryop::Not,
(&Some(ctx.primitives.bool), cmp_phi.into()) (&Some(ctx.primitives.bool), cmp_phi.into()),
) )?.into_int_value()
.transpose()
.unwrap()
.and_then(|res| {
res.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
})?
.into_int_value()
} else { } else {
cmp_phi cmp_phi
} }
@ -2333,12 +2303,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
}; };
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
})?; })?.unwrap();
Ok(Some(match cmp_val { Ok(cmp_val.into())
Some(v) => v.into(),
None => return Ok(None),
}))
} }
/// Generates LLVM IR for a comparison operator expression. /// Generates LLVM IR for a comparison operator expression.
@ -2385,6 +2352,7 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
ops, ops,
comparator_vals.as_slice(), comparator_vals.as_slice(),
) )
.map(|res| Some(res.into()))
} }
/// See [`CodeGenerator::gen_expr`]. /// See [`CodeGenerator::gen_expr`].

View File

@ -213,9 +213,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
Binop::normal(Operator::Mult), Binop::normal(Operator::Mult),
(&Some(rhs_dtype), b_kj), (&Some(rhs_dtype), b_kj),
ctx.current_loc, ctx.current_loc,
)? )?;
.unwrap()
.to_basic_value_enum(ctx, generator, dst_dtype)?;
// dst_[...]ij += x // dst_[...]ij += x
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
@ -226,9 +224,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
Binop::normal(Operator::Add), Binop::normal(Operator::Add),
(&Some(dst_dtype), x), (&Some(dst_dtype), x),
ctx.current_loc, ctx.current_loc,
)? )?;
.unwrap()
.to_basic_value_enum(ctx, generator, dst_dtype)?;
ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
Ok(()) Ok(())