[core] codegen/expr: Simplify gen_*_expr_with_values
return value
These functions always return `BasicValueEnum` because they operate on `BasicValueEnum`s, and they also always return a value.
This commit is contained in:
parent
0d8cb909dd
commit
c37c7e8975
@ -1319,7 +1319,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
op: Binop,
|
||||
right: (&Option<Type>, BasicValueEnum<'ctx>),
|
||||
loc: Location,
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let (left_ty, left_val) = left;
|
||||
let (right_ty, right_val) = right;
|
||||
|
||||
@ -1330,14 +1330,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
// which would be unchanged until further unification, which we would never do
|
||||
// when doing code generation for function instances
|
||||
if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) {
|
||||
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, true).into()))
|
||||
Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, true))
|
||||
} else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) {
|
||||
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, false).into()))
|
||||
Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, false))
|
||||
} else if [Operator::LShift, Operator::RShift].contains(&op.base) {
|
||||
let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1);
|
||||
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed).into()))
|
||||
Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed))
|
||||
} else if ty1 == ty2 && ctx.primitives.float == ty1 {
|
||||
Ok(Some(ctx.gen_float_ops(op.base, left_val, right_val).into()))
|
||||
Ok(ctx.gen_float_ops(op.base, left_val, right_val))
|
||||
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
|
||||
// Pow is the only operator that would pass typecheck between float and int
|
||||
assert_eq!(op.base, Operator::Pow);
|
||||
@ -1347,7 +1347,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
right_val.into_int_value(),
|
||||
Some("f_pow_i"),
|
||||
);
|
||||
Ok(Some(res.into()))
|
||||
Ok(res.into())
|
||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|
||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|
||||
{
|
||||
@ -1437,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
ctx.ctx.bool_type().const_zero(),
|
||||
);
|
||||
|
||||
Ok(Some(new_list.as_abi_value(ctx).into()))
|
||||
Ok(new_list.as_abi_value(ctx).into())
|
||||
}
|
||||
|
||||
Operator::Mult => {
|
||||
@ -1524,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
Ok(Some(new_list.as_abi_value(ctx).into()))
|
||||
Ok(new_list.as_abi_value(ctx).into())
|
||||
}
|
||||
|
||||
_ => todo!("Operator not supported"),
|
||||
@ -1563,7 +1563,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
let result = left
|
||||
.matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out))
|
||||
.split_unsized(generator, ctx);
|
||||
Ok(Some(result.to_basic_value_enum().into()))
|
||||
Ok(result.to_basic_value_enum())
|
||||
} else {
|
||||
// For other operations, they are all elementwise operations.
|
||||
|
||||
@ -1594,14 +1594,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
op,
|
||||
(&Some(ty2_dtype), right_value),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, common_dtype)?;
|
||||
)?;
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
.unwrap();
|
||||
Ok(Some(result.as_abi_value(ctx).into()))
|
||||
Ok(result.as_abi_value(ctx).into())
|
||||
}
|
||||
} else {
|
||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||
@ -1650,7 +1648,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
(&signature, fun_id),
|
||||
vec![(None, right_val.into())],
|
||||
)
|
||||
.map(|f| f.map(Into::into))
|
||||
.map(Option::unwrap)
|
||||
.map(BasicValueEnum::into)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1688,6 +1687,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
||||
(&right.custom, right_val),
|
||||
loc,
|
||||
)
|
||||
.map(|res| Some(res.into()))
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for a unary operator expression using the [`Type`] and
|
||||
@ -1697,11 +1697,11 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
op: ast::Unaryop,
|
||||
operand: (&Option<Type>, BasicValueEnum<'ctx>),
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let (ty, val) = operand;
|
||||
let ty = ctx.unifier.get_representative(ty.unwrap());
|
||||
|
||||
Ok(Some(if ty == ctx.primitives.bool {
|
||||
Ok(if ty == ctx.primitives.bool {
|
||||
let val = val.into_int_value();
|
||||
if op == ast::Unaryop::Not {
|
||||
let not = ctx
|
||||
@ -1722,7 +1722,6 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap(),
|
||||
),
|
||||
)?
|
||||
.unwrap()
|
||||
}
|
||||
} else if [
|
||||
ctx.primitives.int32,
|
||||
@ -1791,16 +1790,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
ctx,
|
||||
NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() },
|
||||
|generator, ctx, scalar| {
|
||||
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))?
|
||||
.map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype))
|
||||
.unwrap()
|
||||
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))
|
||||
},
|
||||
)?;
|
||||
|
||||
mapped_ndarray.as_abi_value(ctx).into()
|
||||
} else {
|
||||
unimplemented!()
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for a unary operator expression.
|
||||
@ -1820,6 +1817,7 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
|
||||
};
|
||||
|
||||
gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val))
|
||||
.map(|res| Some(res.into()))
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for a comparison operator expression using the [`Type`] and
|
||||
@ -1830,7 +1828,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
left: (Option<Type>, BasicValueEnum<'ctx>),
|
||||
ops: &[ast::Cmpop],
|
||||
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
debug_assert_eq!(comparators.len(), ops.len());
|
||||
|
||||
if comparators.len() == 1 {
|
||||
@ -1872,19 +1870,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
(Some(left_ty_dtype), left_scalar),
|
||||
&[op],
|
||||
&[(Some(right_ty_dtype), right_scalar)],
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.primitives.bool,
|
||||
)?;
|
||||
|
||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||
},
|
||||
)?;
|
||||
|
||||
return Ok(Some(result_ndarray.as_abi_value(ctx).into()));
|
||||
return Ok(result_ndarray.as_abi_value(ctx).into());
|
||||
}
|
||||
}
|
||||
|
||||
@ -2007,13 +1999,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
ctx,
|
||||
Unaryop::Not,
|
||||
(&Some(ctx.primitives.bool), result.into()),
|
||||
)
|
||||
.transpose()
|
||||
.unwrap()
|
||||
.and_then(|res| {
|
||||
res.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
||||
})?
|
||||
.into_int_value()
|
||||
)?.into_int_value()
|
||||
} else {
|
||||
result
|
||||
}
|
||||
@ -2116,9 +2102,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
&[Cmpop::Eq],
|
||||
&[(Some(right_elem_ty), right)],
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
|
||||
gen_if_callback(
|
||||
@ -2167,8 +2150,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
Unaryop::Not,
|
||||
(&Some(ctx.primitives.bool), acc.into()),
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)?
|
||||
.into_int_value()
|
||||
} else {
|
||||
acc
|
||||
@ -2256,12 +2237,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
&[op],
|
||||
&[(Some(right_ty), right_elem)],
|
||||
)
|
||||
.transpose()
|
||||
.unwrap()
|
||||
.and_then(|v| {
|
||||
v.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
||||
})?
|
||||
.into_int_value();
|
||||
.map(BasicValueEnum::into_int_value)?;
|
||||
|
||||
Ok(ctx.builder.build_not(
|
||||
generator.bool_to_i1(ctx, cmp),
|
||||
@ -2301,14 +2277,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
generator,
|
||||
ctx,
|
||||
Unaryop::Not,
|
||||
(&Some(ctx.primitives.bool), cmp_phi.into())
|
||||
)
|
||||
.transpose()
|
||||
.unwrap()
|
||||
.and_then(|res| {
|
||||
res.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
||||
})?
|
||||
.into_int_value()
|
||||
(&Some(ctx.primitives.bool), cmp_phi.into()),
|
||||
)?.into_int_value()
|
||||
} else {
|
||||
cmp_phi
|
||||
}
|
||||
@ -2333,12 +2303,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
};
|
||||
|
||||
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
||||
})?;
|
||||
})?.unwrap();
|
||||
|
||||
Ok(Some(match cmp_val {
|
||||
Some(v) => v.into(),
|
||||
None => return Ok(None),
|
||||
}))
|
||||
Ok(cmp_val.into())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for a comparison operator expression.
|
||||
@ -2385,6 +2352,7 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
|
||||
ops,
|
||||
comparator_vals.as_slice(),
|
||||
)
|
||||
.map(|res| Some(res.into()))
|
||||
}
|
||||
|
||||
/// See [`CodeGenerator::gen_expr`].
|
||||
|
@ -213,9 +213,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
||||
Binop::normal(Operator::Mult),
|
||||
(&Some(rhs_dtype), b_kj),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
||||
)?;
|
||||
|
||||
// dst_[...]ij += x
|
||||
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
|
||||
@ -226,9 +224,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
||||
Binop::normal(Operator::Add),
|
||||
(&Some(dst_dtype), x),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
||||
)?;
|
||||
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
|
||||
|
||||
Ok(())
|
||||
|
Loading…
Reference in New Issue
Block a user