[core] codegen/expr: Simplify gen_*_expr_with_values return value

These functions always return `BasicValueEnum` because they operate on
`BasicValueEnum`s, and they also always return a value.
This commit is contained in:
David Mak 2025-02-07 10:39:21 +08:00
parent 0d8cb909dd
commit c37c7e8975
2 changed files with 30 additions and 66 deletions

View File

@ -1319,7 +1319,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
op: Binop,
right: (&Option<Type>, BasicValueEnum<'ctx>),
loc: Location,
) -> Result<Option<ValueEnum<'ctx>>, String> {
) -> Result<BasicValueEnum<'ctx>, String> {
let (left_ty, left_val) = left;
let (right_ty, right_val) = right;
@ -1330,14 +1330,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
// which would be unchanged until further unification, which we would never do
// when doing code generation for function instances
if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) {
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, true).into()))
Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, true))
} else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) {
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, false).into()))
Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, false))
} else if [Operator::LShift, Operator::RShift].contains(&op.base) {
let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1);
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed).into()))
Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed))
} else if ty1 == ty2 && ctx.primitives.float == ty1 {
Ok(Some(ctx.gen_float_ops(op.base, left_val, right_val).into()))
Ok(ctx.gen_float_ops(op.base, left_val, right_val))
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
// Pow is the only operator that would pass typecheck between float and int
assert_eq!(op.base, Operator::Pow);
@ -1347,7 +1347,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
right_val.into_int_value(),
Some("f_pow_i"),
);
Ok(Some(res.into()))
Ok(res.into())
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
{
@ -1437,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
ctx.ctx.bool_type().const_zero(),
);
Ok(Some(new_list.as_abi_value(ctx).into()))
Ok(new_list.as_abi_value(ctx).into())
}
Operator::Mult => {
@ -1524,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
llvm_usize.const_int(1, false),
)?;
Ok(Some(new_list.as_abi_value(ctx).into()))
Ok(new_list.as_abi_value(ctx).into())
}
_ => todo!("Operator not supported"),
@ -1563,7 +1563,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let result = left
.matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out))
.split_unsized(generator, ctx);
Ok(Some(result.to_basic_value_enum().into()))
Ok(result.to_basic_value_enum())
} else {
// For other operations, they are all elementwise operations.
@ -1594,14 +1594,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
op,
(&Some(ty2_dtype), right_value),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, common_dtype)?;
)?;
Ok(result)
})
.unwrap();
Ok(Some(result.as_abi_value(ctx).into()))
Ok(result.as_abi_value(ctx).into())
}
} else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
@ -1650,7 +1648,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
(&signature, fun_id),
vec![(None, right_val.into())],
)
.map(|f| f.map(Into::into))
.map(Option::unwrap)
.map(BasicValueEnum::into)
}
}
@ -1688,6 +1687,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
(&right.custom, right_val),
loc,
)
.map(|res| Some(res.into()))
}
/// Generates LLVM IR for a unary operator expression using the [`Type`] and
@ -1697,11 +1697,11 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, '_>,
op: ast::Unaryop,
operand: (&Option<Type>, BasicValueEnum<'ctx>),
) -> Result<Option<ValueEnum<'ctx>>, String> {
) -> Result<BasicValueEnum<'ctx>, String> {
let (ty, val) = operand;
let ty = ctx.unifier.get_representative(ty.unwrap());
Ok(Some(if ty == ctx.primitives.bool {
Ok(if ty == ctx.primitives.bool {
let val = val.into_int_value();
if op == ast::Unaryop::Not {
let not = ctx
@ -1722,7 +1722,6 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap(),
),
)?
.unwrap()
}
} else if [
ctx.primitives.int32,
@ -1791,16 +1790,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
ctx,
NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() },
|generator, ctx, scalar| {
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))?
.map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype))
.unwrap()
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))
},
)?;
mapped_ndarray.as_abi_value(ctx).into()
} else {
unimplemented!()
}))
})
}
/// Generates LLVM IR for a unary operator expression.
@ -1820,6 +1817,7 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
};
gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val))
.map(|res| Some(res.into()))
}
/// Generates LLVM IR for a comparison operator expression using the [`Type`] and
@ -1830,7 +1828,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
left: (Option<Type>, BasicValueEnum<'ctx>),
ops: &[ast::Cmpop],
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
) -> Result<Option<ValueEnum<'ctx>>, String> {
) -> Result<BasicValueEnum<'ctx>, String> {
debug_assert_eq!(comparators.len(), ops.len());
if comparators.len() == 1 {
@ -1872,19 +1870,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
(Some(left_ty_dtype), left_scalar),
&[op],
&[(Some(right_ty_dtype), right_scalar)],
)?
.unwrap()
.to_basic_value_enum(
ctx,
generator,
ctx.primitives.bool,
)?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
)?;
return Ok(Some(result_ndarray.as_abi_value(ctx).into()));
return Ok(result_ndarray.as_abi_value(ctx).into());
}
}
@ -2007,13 +1999,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
ctx,
Unaryop::Not,
(&Some(ctx.primitives.bool), result.into()),
)
.transpose()
.unwrap()
.and_then(|res| {
res.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
})?
.into_int_value()
)?.into_int_value()
} else {
result
}
@ -2116,9 +2102,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
&[Cmpop::Eq],
&[(Some(right_elem_ty), right)],
)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
.unwrap()
.into_int_value();
gen_if_callback(
@ -2167,8 +2150,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
Unaryop::Not,
(&Some(ctx.primitives.bool), acc.into()),
)?
.unwrap()
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)?
.into_int_value()
} else {
acc
@ -2256,12 +2237,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
&[op],
&[(Some(right_ty), right_elem)],
)
.transpose()
.unwrap()
.and_then(|v| {
v.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
})?
.into_int_value();
.map(BasicValueEnum::into_int_value)?;
Ok(ctx.builder.build_not(
generator.bool_to_i1(ctx, cmp),
@ -2301,14 +2277,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
generator,
ctx,
Unaryop::Not,
(&Some(ctx.primitives.bool), cmp_phi.into())
)
.transpose()
.unwrap()
.and_then(|res| {
res.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
})?
.into_int_value()
(&Some(ctx.primitives.bool), cmp_phi.into()),
)?.into_int_value()
} else {
cmp_phi
}
@ -2333,12 +2303,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
};
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
})?;
})?.unwrap();
Ok(Some(match cmp_val {
Some(v) => v.into(),
None => return Ok(None),
}))
Ok(cmp_val.into())
}
/// Generates LLVM IR for a comparison operator expression.
@ -2385,6 +2352,7 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
ops,
comparator_vals.as_slice(),
)
.map(|res| Some(res.into()))
}
/// See [`CodeGenerator::gen_expr`].

View File

@ -213,9 +213,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
Binop::normal(Operator::Mult),
(&Some(rhs_dtype), b_kj),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, dst_dtype)?;
)?;
// dst_[...]ij += x
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
@ -226,9 +224,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
Binop::normal(Operator::Add),
(&Some(dst_dtype), x),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, dst_dtype)?;
)?;
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
Ok(())