diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 5efbe4a5..feec51ea 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -101,7 +101,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( .iter() .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { + if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) { ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() } else { ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() @@ -241,7 +241,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( .iter() .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { + if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) { ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() } else { ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() @@ -304,20 +304,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + debug_assert!(n_ty.is_integral(&mut ctx.unifier, &ctx.primitives)); - if [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty)) - { + if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) { ctx.builder .build_signed_int_to_float(n, llvm_f64, "sitofp") .map(Into::into) @@ -331,7 +320,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); n.into() } @@ -373,7 +362,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( Ok(match n { BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); let val = llvm_intrinsics::call_float_round(ctx, n, None); ctx.builder @@ -417,7 +406,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( Ok(match n { BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); llvm_intrinsics::call_float_roundeven(ctx, n, None).into() } @@ -463,14 +452,10 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::IntValue(n) => { - debug_assert!([ - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); + debug_assert!( + n_ty.is_integral(&mut ctx.unifier, &ctx.primitives) + && n_ty.is_arithmetic(&mut ctx.unifier, &ctx.primitives) + ); ctx.builder .build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME) @@ -479,7 +464,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); ctx.builder .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME) @@ -528,7 +513,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( Ok(match n { BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); let val = llvm_intrinsics::call_float_floor(ctx, n, None); if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { @@ -578,7 +563,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( Ok(match n { BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); let val = llvm_intrinsics::call_float_ceil(ctx, n, None); if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { @@ -631,20 +616,9 @@ pub fn call_min<'ctx>( match (m, n) { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty, *ty))); + debug_assert!(common_ty.is_integral(&mut ctx.unifier, &ctx.primitives)); - if [ctx.primitives.int32, ctx.primitives.int64] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty, *ty)) - { + if common_ty.is_signed(&mut ctx.unifier, &ctx.primitives) { llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into() } else { llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into() @@ -652,7 +626,7 @@ pub fn call_min<'ctx>( } (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { - debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); + debug_assert!(common_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into() } @@ -675,16 +649,10 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( Ok(match a { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ctx.primitives.float, - ] - .iter() - .any(|ty| ctx.unifier.unioned(a_ty, *ty))); + debug_assert!( + a_ty.is_integral(&mut ctx.unifier, &ctx.primitives) + || a_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives) + ); a } @@ -761,22 +729,13 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( Ok(match (x1, x2) { (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ctx.primitives.float, - ] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); + debug_assert!(common_ty.unwrap().is_integral(&mut ctx.unifier, &ctx.primitives)); call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); + debug_assert!(common_ty.unwrap().is_floating_point(&mut ctx.unifier, &ctx.primitives)); call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } @@ -847,20 +806,9 @@ pub fn call_max<'ctx>( match (m, n) { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty, *ty))); + debug_assert!(common_ty.is_integral(&mut ctx.unifier, &ctx.primitives)); - if [ctx.primitives.int32, ctx.primitives.int64] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty, *ty)) - { + if common_ty.is_signed(&mut ctx.unifier, &ctx.primitives) { llvm_intrinsics::call_int_smax(ctx, m, n, Some(FN_NAME)).into() } else { llvm_intrinsics::call_int_umax(ctx, m, n, Some(FN_NAME)).into() @@ -868,7 +816,7 @@ pub fn call_max<'ctx>( } (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { - debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); + debug_assert!(common_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); llvm_intrinsics::call_float_maxnum(ctx, m, n, Some(FN_NAME)).into() } @@ -891,16 +839,10 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( Ok(match a { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ctx.primitives.float, - ] - .iter() - .any(|ty| ctx.unifier.unioned(a_ty, *ty))); + debug_assert!( + a_ty.is_integral(&mut ctx.unifier, &ctx.primitives) + || a_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives) + ); a } @@ -977,22 +919,13 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( Ok(match (x1, x2) { (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ctx.primitives.float, - ] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); + debug_assert!(common_ty.unwrap().is_integral(&mut ctx.unifier, &ctx.primitives)); call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); + debug_assert!(common_ty.unwrap().is_floating_point(&mut ctx.unifier, &ctx.primitives)); call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } @@ -1117,22 +1050,11 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( n, FN_NAME, &|_ctx, elem_ty| elem_ty, - &|_generator, ctx, val_ty, val| match val { + &|_, ctx, val_ty, val| match val { BasicValueEnum::IntValue(n) => Some({ - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(val_ty, *ty))); + debug_assert!(val_ty.is_integral(&mut ctx.unifier, &ctx.primitives)); - if [ctx.primitives.int32, ctx.primitives.int64] - .iter() - .any(|ty| ctx.unifier.unioned(val_ty, *ty)) - { + if val_ty.is_signed(&mut ctx.unifier, &ctx.primitives) { llvm_intrinsics::call_int_abs( ctx, n, @@ -1146,7 +1068,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( }), BasicValueEnum::FloatValue(n) => Some({ - debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float)); + debug_assert!(val_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into() }), @@ -1431,8 +1353,8 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); + debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); extern_fns::call_atan2(ctx, x1, x2, None).into() } @@ -1498,8 +1420,8 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); + debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() } @@ -1565,8 +1487,8 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); + debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() } @@ -1632,8 +1554,8 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); + debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() } @@ -1699,7 +1621,7 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); + debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); extern_fns::call_ldexp(ctx, x1, x2, None).into() @@ -1755,8 +1677,8 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); + debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); extern_fns::call_hypot(ctx, x1, x2, None).into() } @@ -1822,8 +1744,8 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); + debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)); extern_fns::call_nextafter(ctx, x1, x2, None).into() } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index af2fd8de..bb4ab244 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -74,6 +74,34 @@ impl PrimitiveStore { _ => unreachable!(), } } + + /// Returns an iterator over all primitive types in this store. + fn iter(&self) -> impl Iterator { + self.into_iter() + } +} + +impl IntoIterator for &PrimitiveStore { + type Item = Type; + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + vec![ + self.int32, + self.int64, + self.uint32, + self.uint64, + self.float, + self.bool, + self.none, + self.range, + self.str, + self.exception, + self.option, + self.ndarray, + ] + .into_iter() + } } pub struct FunctionData { diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 6c2ffbc5..594ecce1 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -22,6 +22,16 @@ mod test; /// Handle for a type, implemented as a key in the unification table. pub type Type = UnificationKey; +/// Macro for generating functions related to type traits, e.g. whether the type is integral. +macro_rules! primitive_type_trait_fn { + ($id:ident, $( $matches:ident ),*) => { + #[must_use] + pub fn $id(self, unifier: &mut Unifier, store: &PrimitiveStore) -> bool { + [$(store.$matches,)*].into_iter().any(|ty| unifier.unioned(self, ty)) + } + }; +} + impl Type { /// Wrapper function for cleaner code so that we don't need to write this long pattern matching /// just to get the field `obj_id`. @@ -33,6 +43,17 @@ impl Type { None } } + + #[must_use] + pub fn is_primitive(self, unifier: &mut Unifier, store: &PrimitiveStore) -> bool { + store.into_iter().any(|ty| unifier.unioned(self, ty)) + } + + primitive_type_trait_fn!(is_integral, bool, int32, int64, uint32, uint64); + primitive_type_trait_fn!(is_floating_point, float); + primitive_type_trait_fn!(is_arithmetic, int32, int64, uint32, uint64, float); + primitive_type_trait_fn!(is_signed, int32, uint32, float); + primitive_type_trait_fn!(is_unsigned, uint32, uint64); } #[derive(Clone, Copy, PartialEq, Eq, Debug)]