forked from M-Labs/nac3
1
0
Fork 0

core/typedef: Add type trait functions to Type

This commit is contained in:
David Mak 2024-06-25 16:50:44 +08:00
parent c78accce70
commit 10a88e1799
3 changed files with 98 additions and 127 deletions

View File

@ -101,7 +101,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
.iter() .iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty))); .any(|ty| ctx.unifier.unioned(n_ty, *ty)));
if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap()
} else { } else {
ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap()
@ -241,7 +241,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
.iter() .iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty))); .any(|ty| ctx.unifier.unioned(n_ty, *ty)));
if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap()
} else { } else {
ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap()
@ -304,20 +304,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n { Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => {
debug_assert!([ debug_assert!(n_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
if [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64] if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty))
{
ctx.builder ctx.builder
.build_signed_int_to_float(n, llvm_f64, "sitofp") .build_signed_int_to_float(n, llvm_f64, "sitofp")
.map(Into::into) .map(Into::into)
@ -331,7 +320,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
} }
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
n.into() n.into()
} }
@ -373,7 +362,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n { Ok(match n {
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
let val = llvm_intrinsics::call_float_round(ctx, n, None); let val = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder ctx.builder
@ -417,7 +406,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n { Ok(match n {
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_roundeven(ctx, n, None).into() llvm_intrinsics::call_float_roundeven(ctx, n, None).into()
} }
@ -463,14 +452,10 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
} }
BasicValueEnum::IntValue(n) => { BasicValueEnum::IntValue(n) => {
debug_assert!([ debug_assert!(
ctx.primitives.int32, n_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.uint32, && n_ty.is_arithmetic(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.int64, );
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
ctx.builder ctx.builder
.build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME) .build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME)
@ -479,7 +464,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
} }
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
ctx.builder ctx.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME) .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME)
@ -528,7 +513,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n { Ok(match n {
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
let val = llvm_intrinsics::call_float_floor(ctx, n, None); let val = llvm_intrinsics::call_float_floor(ctx, n, None);
if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty {
@ -578,7 +563,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n { Ok(match n {
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
let val = llvm_intrinsics::call_float_ceil(ctx, n, None); let val = llvm_intrinsics::call_float_ceil(ctx, n, None);
if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty {
@ -631,20 +616,9 @@ pub fn call_min<'ctx>(
match (m, n) { match (m, n) {
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
debug_assert!([ debug_assert!(common_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty)));
if [ctx.primitives.int32, ctx.primitives.int64] if common_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty))
{
llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into()
} else { } else {
llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into()
@ -652,7 +626,7 @@ pub fn call_min<'ctx>(
} }
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); debug_assert!(common_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into()
} }
@ -675,16 +649,10 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
Ok(match a { Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([ debug_assert!(
ctx.primitives.bool, a_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.int32, || a_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.uint32, );
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a a
} }
@ -761,22 +729,13 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!([ debug_assert!(common_ty.unwrap().is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
} }
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); debug_assert!(common_ty.unwrap().is_floating_point(&mut ctx.unifier, &ctx.primitives));
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
} }
@ -847,20 +806,9 @@ pub fn call_max<'ctx>(
match (m, n) { match (m, n) {
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
debug_assert!([ debug_assert!(common_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty)));
if [ctx.primitives.int32, ctx.primitives.int64] if common_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty))
{
llvm_intrinsics::call_int_smax(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_int_smax(ctx, m, n, Some(FN_NAME)).into()
} else { } else {
llvm_intrinsics::call_int_umax(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_int_umax(ctx, m, n, Some(FN_NAME)).into()
@ -868,7 +816,7 @@ pub fn call_max<'ctx>(
} }
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); debug_assert!(common_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_maxnum(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_float_maxnum(ctx, m, n, Some(FN_NAME)).into()
} }
@ -891,16 +839,10 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
Ok(match a { Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([ debug_assert!(
ctx.primitives.bool, a_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.int32, || a_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.uint32, );
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a a
} }
@ -977,22 +919,13 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!([ debug_assert!(common_ty.unwrap().is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
} }
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); debug_assert!(common_ty.unwrap().is_floating_point(&mut ctx.unifier, &ctx.primitives));
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
} }
@ -1117,22 +1050,11 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
n, n,
FN_NAME, FN_NAME,
&|_ctx, elem_ty| elem_ty, &|_ctx, elem_ty| elem_ty,
&|_generator, ctx, val_ty, val| match val { &|_, ctx, val_ty, val| match val {
BasicValueEnum::IntValue(n) => Some({ BasicValueEnum::IntValue(n) => Some({
debug_assert!([ debug_assert!(val_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(val_ty, *ty)));
if [ctx.primitives.int32, ctx.primitives.int64] if val_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
.iter()
.any(|ty| ctx.unifier.unioned(val_ty, *ty))
{
llvm_intrinsics::call_int_abs( llvm_intrinsics::call_int_abs(
ctx, ctx,
n, n,
@ -1146,7 +1068,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
}), }),
BasicValueEnum::FloatValue(n) => Some({ BasicValueEnum::FloatValue(n) => Some({
debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float)); debug_assert!(val_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into() llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into()
}), }),
@ -1431,8 +1353,8 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
extern_fns::call_atan2(ctx, x1, x2, None).into() extern_fns::call_atan2(ctx, x1, x2, None).into()
} }
@ -1498,8 +1420,8 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into()
} }
@ -1565,8 +1487,8 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into()
} }
@ -1632,8 +1554,8 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into()
} }
@ -1699,7 +1621,7 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32));
extern_fns::call_ldexp(ctx, x1, x2, None).into() extern_fns::call_ldexp(ctx, x1, x2, None).into()
@ -1755,8 +1677,8 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
extern_fns::call_hypot(ctx, x1, x2, None).into() extern_fns::call_hypot(ctx, x1, x2, None).into()
} }
@ -1822,8 +1744,8 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
extern_fns::call_nextafter(ctx, x1, x2, None).into() extern_fns::call_nextafter(ctx, x1, x2, None).into()
} }

View File

@ -74,6 +74,34 @@ impl PrimitiveStore {
_ => unreachable!(), _ => unreachable!(),
} }
} }
/// Returns an iterator over all primitive types in this store.
fn iter(&self) -> impl Iterator<Item = Type> {
self.into_iter()
}
}
impl IntoIterator for &PrimitiveStore {
type Item = Type;
type IntoIter = <Vec<Type> as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
vec![
self.int32,
self.int64,
self.uint32,
self.uint64,
self.float,
self.bool,
self.none,
self.range,
self.str,
self.exception,
self.option,
self.ndarray,
]
.into_iter()
}
} }
pub struct FunctionData { pub struct FunctionData {

View File

@ -22,6 +22,16 @@ mod test;
/// Handle for a type, implemented as a key in the unification table. /// Handle for a type, implemented as a key in the unification table.
pub type Type = UnificationKey; pub type Type = UnificationKey;
/// Macro for generating functions related to type traits, e.g. whether the type is integral.
macro_rules! primitive_type_trait_fn {
($id:ident, $( $matches:ident ),*) => {
#[must_use]
pub fn $id(self, unifier: &mut Unifier, store: &PrimitiveStore) -> bool {
[$(store.$matches,)*].into_iter().any(|ty| unifier.unioned(self, ty))
}
};
}
impl Type { impl Type {
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching /// Wrapper function for cleaner code so that we don't need to write this long pattern matching
/// just to get the field `obj_id`. /// just to get the field `obj_id`.
@ -33,6 +43,17 @@ impl Type {
None None
} }
} }
#[must_use]
pub fn is_primitive(self, unifier: &mut Unifier, store: &PrimitiveStore) -> bool {
store.into_iter().any(|ty| unifier.unioned(self, ty))
}
primitive_type_trait_fn!(is_integral, bool, int32, int64, uint32, uint64);
primitive_type_trait_fn!(is_floating_point, float);
primitive_type_trait_fn!(is_arithmetic, int32, int64, uint32, uint64, float);
primitive_type_trait_fn!(is_signed, int32, uint32, float);
primitive_type_trait_fn!(is_unsigned, uint32, uint64);
} }
#[derive(Clone, Copy, PartialEq, Eq, Debug)] #[derive(Clone, Copy, PartialEq, Eq, Debug)]