[core] codegen/expr: Simplify gen_*_expr_with_values
return value
These functions always return `BasicValueEnum` because they operate on `BasicValueEnum`s, and they also always return a value.
This commit is contained in:
parent
0d8cb909dd
commit
c37c7e8975
@ -1319,7 +1319,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
op: Binop,
|
op: Binop,
|
||||||
right: (&Option<Type>, BasicValueEnum<'ctx>),
|
right: (&Option<Type>, BasicValueEnum<'ctx>),
|
||||||
loc: Location,
|
loc: Location,
|
||||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
let (left_ty, left_val) = left;
|
let (left_ty, left_val) = left;
|
||||||
let (right_ty, right_val) = right;
|
let (right_ty, right_val) = right;
|
||||||
|
|
||||||
@ -1330,14 +1330,14 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
// which would be unchanged until further unification, which we would never do
|
// which would be unchanged until further unification, which we would never do
|
||||||
// when doing code generation for function instances
|
// when doing code generation for function instances
|
||||||
if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) {
|
if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) {
|
||||||
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, true).into()))
|
Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, true))
|
||||||
} else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) {
|
} else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) {
|
||||||
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, false).into()))
|
Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, false))
|
||||||
} else if [Operator::LShift, Operator::RShift].contains(&op.base) {
|
} else if [Operator::LShift, Operator::RShift].contains(&op.base) {
|
||||||
let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1);
|
let signed = [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1);
|
||||||
Ok(Some(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed).into()))
|
Ok(ctx.gen_int_ops(generator, op.base, left_val, right_val, signed))
|
||||||
} else if ty1 == ty2 && ctx.primitives.float == ty1 {
|
} else if ty1 == ty2 && ctx.primitives.float == ty1 {
|
||||||
Ok(Some(ctx.gen_float_ops(op.base, left_val, right_val).into()))
|
Ok(ctx.gen_float_ops(op.base, left_val, right_val))
|
||||||
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
|
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
|
||||||
// Pow is the only operator that would pass typecheck between float and int
|
// Pow is the only operator that would pass typecheck between float and int
|
||||||
assert_eq!(op.base, Operator::Pow);
|
assert_eq!(op.base, Operator::Pow);
|
||||||
@ -1347,7 +1347,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
right_val.into_int_value(),
|
right_val.into_int_value(),
|
||||||
Some("f_pow_i"),
|
Some("f_pow_i"),
|
||||||
);
|
);
|
||||||
Ok(Some(res.into()))
|
Ok(res.into())
|
||||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|
||||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|
||||||
{
|
{
|
||||||
@ -1437,7 +1437,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
ctx.ctx.bool_type().const_zero(),
|
ctx.ctx.bool_type().const_zero(),
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(Some(new_list.as_abi_value(ctx).into()))
|
Ok(new_list.as_abi_value(ctx).into())
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator::Mult => {
|
Operator::Mult => {
|
||||||
@ -1524,7 +1524,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
llvm_usize.const_int(1, false),
|
llvm_usize.const_int(1, false),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Some(new_list.as_abi_value(ctx).into()))
|
Ok(new_list.as_abi_value(ctx).into())
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => todo!("Operator not supported"),
|
_ => todo!("Operator not supported"),
|
||||||
@ -1563,7 +1563,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
let result = left
|
let result = left
|
||||||
.matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out))
|
.matmul(generator, ctx, ty1, (ty2, right), (common_dtype, out))
|
||||||
.split_unsized(generator, ctx);
|
.split_unsized(generator, ctx);
|
||||||
Ok(Some(result.to_basic_value_enum().into()))
|
Ok(result.to_basic_value_enum())
|
||||||
} else {
|
} else {
|
||||||
// For other operations, they are all elementwise operations.
|
// For other operations, they are all elementwise operations.
|
||||||
|
|
||||||
@ -1594,14 +1594,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
op,
|
op,
|
||||||
(&Some(ty2_dtype), right_value),
|
(&Some(ty2_dtype), right_value),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
)?
|
)?;
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, common_dtype)?;
|
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
Ok(Some(result.as_abi_value(ctx).into()))
|
Ok(result.as_abi_value(ctx).into())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||||
@ -1650,7 +1648,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
(&signature, fun_id),
|
(&signature, fun_id),
|
||||||
vec![(None, right_val.into())],
|
vec![(None, right_val.into())],
|
||||||
)
|
)
|
||||||
.map(|f| f.map(Into::into))
|
.map(Option::unwrap)
|
||||||
|
.map(BasicValueEnum::into)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1688,6 +1687,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
|||||||
(&right.custom, right_val),
|
(&right.custom, right_val),
|
||||||
loc,
|
loc,
|
||||||
)
|
)
|
||||||
|
.map(|res| Some(res.into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for a unary operator expression using the [`Type`] and
|
/// Generates LLVM IR for a unary operator expression using the [`Type`] and
|
||||||
@ -1697,11 +1697,11 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
op: ast::Unaryop,
|
op: ast::Unaryop,
|
||||||
operand: (&Option<Type>, BasicValueEnum<'ctx>),
|
operand: (&Option<Type>, BasicValueEnum<'ctx>),
|
||||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
let (ty, val) = operand;
|
let (ty, val) = operand;
|
||||||
let ty = ctx.unifier.get_representative(ty.unwrap());
|
let ty = ctx.unifier.get_representative(ty.unwrap());
|
||||||
|
|
||||||
Ok(Some(if ty == ctx.primitives.bool {
|
Ok(if ty == ctx.primitives.bool {
|
||||||
let val = val.into_int_value();
|
let val = val.into_int_value();
|
||||||
if op == ast::Unaryop::Not {
|
if op == ast::Unaryop::Not {
|
||||||
let not = ctx
|
let not = ctx
|
||||||
@ -1722,7 +1722,6 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap(),
|
ctx.builder.build_int_z_extend(val, llvm_i32, "").map(Into::into).unwrap(),
|
||||||
),
|
),
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
|
||||||
}
|
}
|
||||||
} else if [
|
} else if [
|
||||||
ctx.primitives.int32,
|
ctx.primitives.int32,
|
||||||
@ -1791,16 +1790,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
ctx,
|
ctx,
|
||||||
NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() },
|
NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() },
|
||||||
|generator, ctx, scalar| {
|
|generator, ctx, scalar| {
|
||||||
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))?
|
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))
|
||||||
.map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype))
|
|
||||||
.unwrap()
|
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
mapped_ndarray.as_abi_value(ctx).into()
|
mapped_ndarray.as_abi_value(ctx).into()
|
||||||
} else {
|
} else {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}))
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for a unary operator expression.
|
/// Generates LLVM IR for a unary operator expression.
|
||||||
@ -1820,6 +1817,7 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
|
|||||||
};
|
};
|
||||||
|
|
||||||
gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val))
|
gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val))
|
||||||
|
.map(|res| Some(res.into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for a comparison operator expression using the [`Type`] and
|
/// Generates LLVM IR for a comparison operator expression using the [`Type`] and
|
||||||
@ -1830,7 +1828,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
left: (Option<Type>, BasicValueEnum<'ctx>),
|
left: (Option<Type>, BasicValueEnum<'ctx>),
|
||||||
ops: &[ast::Cmpop],
|
ops: &[ast::Cmpop],
|
||||||
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
|
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
|
||||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
debug_assert_eq!(comparators.len(), ops.len());
|
debug_assert_eq!(comparators.len(), ops.len());
|
||||||
|
|
||||||
if comparators.len() == 1 {
|
if comparators.len() == 1 {
|
||||||
@ -1872,19 +1870,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
(Some(left_ty_dtype), left_scalar),
|
(Some(left_ty_dtype), left_scalar),
|
||||||
&[op],
|
&[op],
|
||||||
&[(Some(right_ty_dtype), right_scalar)],
|
&[(Some(right_ty_dtype), right_scalar)],
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
return Ok(Some(result_ndarray.as_abi_value(ctx).into()));
|
return Ok(result_ndarray.as_abi_value(ctx).into());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2007,13 +1999,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
ctx,
|
ctx,
|
||||||
Unaryop::Not,
|
Unaryop::Not,
|
||||||
(&Some(ctx.primitives.bool), result.into()),
|
(&Some(ctx.primitives.bool), result.into()),
|
||||||
)
|
)?.into_int_value()
|
||||||
.transpose()
|
|
||||||
.unwrap()
|
|
||||||
.and_then(|res| {
|
|
||||||
res.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
|
||||||
})?
|
|
||||||
.into_int_value()
|
|
||||||
} else {
|
} else {
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
@ -2116,9 +2102,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
&[Cmpop::Eq],
|
&[Cmpop::Eq],
|
||||||
&[(Some(right_elem_ty), right)],
|
&[(Some(right_elem_ty), right)],
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value();
|
.into_int_value();
|
||||||
|
|
||||||
gen_if_callback(
|
gen_if_callback(
|
||||||
@ -2167,8 +2150,6 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
Unaryop::Not,
|
Unaryop::Not,
|
||||||
(&Some(ctx.primitives.bool), acc.into()),
|
(&Some(ctx.primitives.bool), acc.into()),
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, ctx.primitives.bool)?
|
|
||||||
.into_int_value()
|
.into_int_value()
|
||||||
} else {
|
} else {
|
||||||
acc
|
acc
|
||||||
@ -2256,12 +2237,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
&[op],
|
&[op],
|
||||||
&[(Some(right_ty), right_elem)],
|
&[(Some(right_ty), right_elem)],
|
||||||
)
|
)
|
||||||
.transpose()
|
.map(BasicValueEnum::into_int_value)?;
|
||||||
.unwrap()
|
|
||||||
.and_then(|v| {
|
|
||||||
v.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
|
||||||
})?
|
|
||||||
.into_int_value();
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_not(
|
Ok(ctx.builder.build_not(
|
||||||
generator.bool_to_i1(ctx, cmp),
|
generator.bool_to_i1(ctx, cmp),
|
||||||
@ -2301,14 +2277,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
Unaryop::Not,
|
Unaryop::Not,
|
||||||
(&Some(ctx.primitives.bool), cmp_phi.into())
|
(&Some(ctx.primitives.bool), cmp_phi.into()),
|
||||||
)
|
)?.into_int_value()
|
||||||
.transpose()
|
|
||||||
.unwrap()
|
|
||||||
.and_then(|res| {
|
|
||||||
res.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
|
||||||
})?
|
|
||||||
.into_int_value()
|
|
||||||
} else {
|
} else {
|
||||||
cmp_phi
|
cmp_phi
|
||||||
}
|
}
|
||||||
@ -2333,12 +2303,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
};
|
};
|
||||||
|
|
||||||
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
||||||
})?;
|
})?.unwrap();
|
||||||
|
|
||||||
Ok(Some(match cmp_val {
|
Ok(cmp_val.into())
|
||||||
Some(v) => v.into(),
|
|
||||||
None => return Ok(None),
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for a comparison operator expression.
|
/// Generates LLVM IR for a comparison operator expression.
|
||||||
@ -2385,6 +2352,7 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
|
|||||||
ops,
|
ops,
|
||||||
comparator_vals.as_slice(),
|
comparator_vals.as_slice(),
|
||||||
)
|
)
|
||||||
|
.map(|res| Some(res.into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_expr`].
|
/// See [`CodeGenerator::gen_expr`].
|
||||||
|
@ -213,9 +213,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
|||||||
Binop::normal(Operator::Mult),
|
Binop::normal(Operator::Mult),
|
||||||
(&Some(rhs_dtype), b_kj),
|
(&Some(rhs_dtype), b_kj),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
)?
|
)?;
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
|
||||||
|
|
||||||
// dst_[...]ij += x
|
// dst_[...]ij += x
|
||||||
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
|
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
|
||||||
@ -226,9 +224,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
|||||||
Binop::normal(Operator::Add),
|
Binop::normal(Operator::Add),
|
||||||
(&Some(dst_dtype), x),
|
(&Some(dst_dtype), x),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
)?
|
)?;
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
|
||||||
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
|
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
Loading…
Reference in New Issue
Block a user