1
0
forked from M-Labs/nac3

core/typedef: Add type trait functions to Type

This commit is contained in:
David Mak 2024-06-25 16:50:44 +08:00
parent c78accce70
commit 10a88e1799
3 changed files with 98 additions and 127 deletions

View File

@ -101,7 +101,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
if ctx.unifier.unioned(n_ty, ctx.primitives.int32) {
if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap()
} else {
ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap()
@ -241,7 +241,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
if ctx.unifier.unioned(n_ty, ctx.primitives.int32) {
if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap()
} else {
ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap()
@ -304,20 +304,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
debug_assert!(n_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
if [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64]
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty))
{
if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
ctx.builder
.build_signed_int_to_float(n, llvm_f64, "sitofp")
.map(Into::into)
@ -331,7 +320,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
}
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
n.into()
}
@ -373,7 +362,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n {
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
let val = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder
@ -417,7 +406,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n {
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_roundeven(ctx, n, None).into()
}
@ -463,14 +452,10 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
}
BasicValueEnum::IntValue(n) => {
debug_assert!([
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
debug_assert!(
n_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
&& n_ty.is_arithmetic(&mut ctx.unifier, &ctx.primitives)
);
ctx.builder
.build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME)
@ -479,7 +464,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
}
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
ctx.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME)
@ -528,7 +513,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n {
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
let val = llvm_intrinsics::call_float_floor(ctx, n, None);
if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty {
@ -578,7 +563,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n {
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
let val = llvm_intrinsics::call_float_ceil(ctx, n, None);
if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty {
@ -631,20 +616,9 @@ pub fn call_min<'ctx>(
match (m, n) {
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty)));
debug_assert!(common_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
if [ctx.primitives.int32, ctx.primitives.int64]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty))
{
if common_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into()
} else {
llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into()
@ -652,7 +626,7 @@ pub fn call_min<'ctx>(
}
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float));
debug_assert!(common_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into()
}
@ -675,16 +649,10 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
debug_assert!(
a_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
|| a_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)
);
a
}
@ -761,22 +729,13 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) {
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
debug_assert!(common_ty.unwrap().is_integral(&mut ctx.unifier, &ctx.primitives));
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
debug_assert!(common_ty.unwrap().is_floating_point(&mut ctx.unifier, &ctx.primitives));
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
@ -847,20 +806,9 @@ pub fn call_max<'ctx>(
match (m, n) {
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty)));
debug_assert!(common_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
if [ctx.primitives.int32, ctx.primitives.int64]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty))
{
if common_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
llvm_intrinsics::call_int_smax(ctx, m, n, Some(FN_NAME)).into()
} else {
llvm_intrinsics::call_int_umax(ctx, m, n, Some(FN_NAME)).into()
@ -868,7 +816,7 @@ pub fn call_max<'ctx>(
}
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float));
debug_assert!(common_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_maxnum(ctx, m, n, Some(FN_NAME)).into()
}
@ -891,16 +839,10 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
debug_assert!(
a_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
|| a_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)
);
a
}
@ -977,22 +919,13 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) {
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
debug_assert!(common_ty.unwrap().is_integral(&mut ctx.unifier, &ctx.primitives));
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
debug_assert!(common_ty.unwrap().is_floating_point(&mut ctx.unifier, &ctx.primitives));
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
@ -1117,22 +1050,11 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
n,
FN_NAME,
&|_ctx, elem_ty| elem_ty,
&|_generator, ctx, val_ty, val| match val {
&|_, ctx, val_ty, val| match val {
BasicValueEnum::IntValue(n) => Some({
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(val_ty, *ty)));
debug_assert!(val_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
if [ctx.primitives.int32, ctx.primitives.int64]
.iter()
.any(|ty| ctx.unifier.unioned(val_ty, *ty))
{
if val_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
llvm_intrinsics::call_int_abs(
ctx,
n,
@ -1146,7 +1068,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
}),
BasicValueEnum::FloatValue(n) => Some({
debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float));
debug_assert!(val_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into()
}),
@ -1431,8 +1353,8 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
extern_fns::call_atan2(ctx, x1, x2, None).into()
}
@ -1498,8 +1420,8 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into()
}
@ -1565,8 +1487,8 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into()
}
@ -1632,8 +1554,8 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into()
}
@ -1699,7 +1621,7 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32));
extern_fns::call_ldexp(ctx, x1, x2, None).into()
@ -1755,8 +1677,8 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
extern_fns::call_hypot(ctx, x1, x2, None).into()
}
@ -1822,8 +1744,8 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
extern_fns::call_nextafter(ctx, x1, x2, None).into()
}

View File

@ -74,6 +74,34 @@ impl PrimitiveStore {
_ => unreachable!(),
}
}
/// Returns an iterator over all primitive types in this store.
fn iter(&self) -> impl Iterator<Item = Type> {
self.into_iter()
}
}
impl IntoIterator for &PrimitiveStore {
type Item = Type;
type IntoIter = <Vec<Type> as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
vec![
self.int32,
self.int64,
self.uint32,
self.uint64,
self.float,
self.bool,
self.none,
self.range,
self.str,
self.exception,
self.option,
self.ndarray,
]
.into_iter()
}
}
pub struct FunctionData {

View File

@ -22,6 +22,16 @@ mod test;
/// Handle for a type, implemented as a key in the unification table.
pub type Type = UnificationKey;
/// Macro for generating functions related to type traits, e.g. whether the type is integral.
macro_rules! primitive_type_trait_fn {
($id:ident, $( $matches:ident ),*) => {
#[must_use]
pub fn $id(self, unifier: &mut Unifier, store: &PrimitiveStore) -> bool {
[$(store.$matches,)*].into_iter().any(|ty| unifier.unioned(self, ty))
}
};
}
impl Type {
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching
/// just to get the field `obj_id`.
@ -33,6 +43,17 @@ impl Type {
None
}
}
#[must_use]
pub fn is_primitive(self, unifier: &mut Unifier, store: &PrimitiveStore) -> bool {
store.into_iter().any(|ty| unifier.unioned(self, ty))
}
primitive_type_trait_fn!(is_integral, bool, int32, int64, uint32, uint64);
primitive_type_trait_fn!(is_floating_point, float);
primitive_type_trait_fn!(is_arithmetic, int32, int64, uint32, uint64, float);
primitive_type_trait_fn!(is_signed, int32, uint32, float);
primitive_type_trait_fn!(is_unsigned, uint32, uint64);
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]