diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 55622d76..e5e74ade 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,13 +1,16 @@ -use inkwell::types::BasicTypeEnum; -use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; +use inkwell::types::{BasicTypeEnum, IntType}; +use inkwell::values::{BasicValue, BasicValueEnum, FloatValue, IntValue, PointerValue}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; use crate::codegen::classes::{ NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; +use crate::codegen::llvm_intrinsics::call_float_rint; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; +use crate::codegen::structure::ndarray::mapping::starmap_scalars_array_like; +use crate::codegen::structure::ndarray::NDArrayObject; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::toplevel::helper::PrimDef; use crate::toplevel::numpy::unpack_ndarray_var_tys; @@ -25,60 +28,74 @@ fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) - ) } +fn handle_cast_to_int_conversion<'ctx, G, F>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + (n_ty, n): (Type, BasicValueEnum<'ctx>), + func_name: &str, + int_type: Type, + handle_float: F, +) -> Result, String> +where + G: CodeGenerator + ?Sized, + F: FnOnce(&mut G, &mut CodeGenContext<'ctx, '_>, FloatValue<'ctx>) -> IntValue<'ctx>, +{ + let llvm_int_type = ctx.get_llvm_type(generator, int_type).into_int_type(); + split_scalar_or_ndarray(generator, ctx, n, n_ty) + .map(generator, ctx, |generator, ctx, _i, scalar| { + let int = match () { + () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.bool) => { + // For booleans, simply extend its number of bits + let n = scalar.value.into_int_value(); + ctx.builder.build_int_z_extend(n, llvm_int_type, "").unwrap() + } + () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { + // For floats, cast it + let n = scalar.value.into_float_value(); + handle_float(generator, ctx, n) + } + () if ctx.unifier.unioned_any( + scalar.dtype, + [ + ctx.primitives.int32, + ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ], + ) => + { + // For int32, int64, uint32, uint64, sign-extend or truncate + let n = scalar.value.into_int_value(); + if n.get_type().get_bit_width() <= llvm_int_type.get_bit_width() { + ctx.builder.build_int_s_extend(n, llvm_int_type, "").unwrap() + } else { + ctx.builder.build_int_truncate(n, llvm_int_type, "").unwrap() + } + } + () => unsupported_type(ctx, func_name, &[scalar.dtype]), + }; + Ok(ScalarObject { dtype: int_type, value: int.as_basic_value_enum() }) + }) + .map(ScalarOrNDArray::to_basic_value_enum) +} + /// Invokes the `int32` builtin function. pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - (n_ty, n): (Type, BasicValueEnum<'ctx>), + n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - split_scalar_or_ndarray(generator, ctx, n, n_ty) - .map(generator, ctx, ctx.primitives.int32, |_generator, ctx, _i, scalar| { - match &*ctx.unifier.get_ty(scalar.dtype) { - TypeEnum::TObj { obj_id, .. 
} - if *obj_id == ctx.primitives.bool.obj_id(&ctx.unifier).unwrap() => - { - Ok(ctx - .builder - .build_int_z_extend(scalar.value.into_int_value(), llvm_i32, "zext") - .unwrap() - .as_basic_value_enum()) - } - TypeEnum::TObj { obj_id, .. } - if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => - { - Ok(scalar.value) - } - TypeEnum::TObj { obj_id, .. } - if *obj_id == ctx.primitives.int64.obj_id(&ctx.unifier).unwrap() => - { - Ok(ctx - .builder - .build_int_truncate(scalar.value.into_int_value(), llvm_i32, "trunc") - .unwrap() - .as_basic_value_enum()) - } - TypeEnum::TObj { obj_id, .. } - if *obj_id == ctx.primitives.float.obj_id(&ctx.unifier).unwrap() => - { - let to_int64 = ctx - .builder - .build_float_to_signed_int( - scalar.value.into_float_value(), - ctx.ctx.i64_type(), - "", - ) - .unwrap(); - Ok(ctx - .builder - .build_int_truncate(to_int64, llvm_i32, "conv") - .unwrap() - .as_basic_value_enum()) - } - _ => unsupported_type(ctx, "int32", &[scalar.dtype]), - } - }) - .map(ScalarOrNDArray::to_basic_value_enum) + handle_cast_to_int_conversion( + generator, + ctx, + n, + "int32", + ctx.primitives.int32, + |_generator, ctx, n| { + let n = ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); + ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap() + }, + ) } /// Invokes the `int64` builtin function. @@ -87,60 +104,16 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { - debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { - ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() - } else { - ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() - } - } - - BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { - debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - n.into() - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - ctx.builder - .build_float_to_signed_int(n, ctx.ctx.i64_type(), "fptosi") - .map(Into::into) - .unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int64, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, "int64", &[n_ty]), - }) + handle_cast_to_int_conversion( + generator, + ctx, + n, + "int64", + ctx.primitives.int64, + |_generator, ctx, n| { + ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap() + }, + ) } /// Invokes the `uint32` builtin function. 
@@ -149,76 +122,34 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); - - ctx.builder.build_int_z_extend(n, llvm_i32, "zext").map(Into::into).unwrap() - } - - BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 32 => { - debug_assert!([ctx.primitives.int32, ctx.primitives.uint32,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - n.into() - } - - BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { - debug_assert!( - ctx.unifier.unioned(n_ty, ctx.primitives.int64) - || ctx.unifier.unioned(n_ty, ctx.primitives.uint64) - ); - - ctx.builder.build_int_truncate(n, llvm_i32, "trunc").map(Into::into).unwrap() - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - + handle_cast_to_int_conversion( + generator, + ctx, + n, + "uint32", + ctx.primitives.uint32, + |_generator, ctx, n| { let n_gez = ctx .builder .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") .unwrap(); - let to_int32 = ctx.builder.build_float_to_signed_int(n, llvm_i32, "").unwrap(); + let to_int32 = + ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap(); let to_uint64 = ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); ctx.builder .build_select( n_gez, - ctx.builder.build_int_truncate(to_uint64, llvm_i32, "").unwrap(), + ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(), to_int32, "conv", ) .unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint32, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, "uint32", &[n_ty]), - }) + .into_int_value() + }, + ) } /// Invokes the `uint64` builtin function. 
@@ -227,518 +158,356 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { - debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { - ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() - } else { - ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() - } - } - - BasicValueEnum::IntValue(n) if n.get_type().get_bit_width() == 64 => { - debug_assert!([ctx.primitives.int64, ctx.primitives.uint64,] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - n.into() - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - + handle_cast_to_int_conversion( + generator, + ctx, + n, + "uint64", + ctx.primitives.uint64, + |_generator, ctx, n| { let val_gez = ctx .builder .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") .unwrap(); - let to_int64 = ctx.builder.build_float_to_signed_int(n, llvm_i64, "").unwrap(); - let to_uint64 = ctx.builder.build_float_to_unsigned_int(n, llvm_i64, "").unwrap(); + let to_int64 = + ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); + let to_uint64 = + ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); - ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint64, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, "uint64", &[n_ty]), - }) + ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap().into_int_value() + }, + ) } /// Invokes the `float` builtin function. 
pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - if [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty)) - { - ctx.builder - .build_signed_int_to_float(n, llvm_f64, "sitofp") - .map(Into::into) - .unwrap() - } else { - ctx.builder - .build_unsigned_int_to_float(n, llvm_f64, "uitofp") - .map(Into::into) - .unwrap() + split_scalar_or_ndarray(generator, ctx, n, n_ty) + .map(generator, ctx, ctx.primitives.float, |_generator, ctx, _i, scalar| { + Ok(match () { + () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { + // Handle floats + scalar.value.into_float_value() + } + () if ctx.unifier.unioned_any( + scalar.dtype, + [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64], + ) => + { + // Handle signed ints and booleans (treating booleans as signed ints for convenience) + let n = scalar.value.into_int_value(); + ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap() + } + () if ctx + .unifier + .unioned_any(scalar.dtype, [ctx.primitives.uint32, ctx.primitives.uint64]) => + { + // Handle unsigned ints + let n = scalar.value.into_int_value(); + ctx.builder + .build_unsigned_int_to_float(n, llvm_f64, "uitofp") + .map(Into::into) + .unwrap() + } + () => unsupported_type(ctx, "float", &[n_ty]), } - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - n.into() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, "float", &[n_ty]), - }) -} - -/// Invokes the `round` builtin function. 
-pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), - ret_elem_ty: Type, -) -> Result, String> { - const FN_NAME: &str = "round"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type(); - - Ok(match n { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - let val = llvm_intrinsics::call_float_round(ctx, n, None); - ctx.builder - .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) - .map(Into::into) - .unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) -} - -/// Invokes the `np_round` builtin function. -pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_round"; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - llvm_intrinsics::call_float_rint(ctx, n, None).into() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) + .as_basic_value_enum()) + }) + .map(ScalarOrNDArray::to_basic_value_enum) } /// Invokes the `bool` builtin function. 
pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "bool"; - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - - Ok(match n { - BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); - - n.into() - } - - BasicValueEnum::IntValue(n) => { - debug_assert!([ - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(n_ty, *ty))); - - ctx.builder - .build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME) - .map(Into::into) - .unwrap() - } - - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - ctx.builder - .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME) - .map(Into::into) - .unwrap() - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - let elem = call_bool(generator, ctx, (elem_ty, val))?; - - Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) - }, - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) + split_scalar_or_ndarray(generator, ctx, n, n_ty) + .map(generator, ctx, ctx.primitives.float, |_generator, ctx, _i, scalar| { + Ok(match () { + () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.bool) => { + // Handle booleans + scalar.value.into_int_value() + } + () if ctx.unifier.unioned_any( + scalar.dtype, + [ + ctx.primitives.int32, + ctx.primitives.int64, + ctx.primitives.uint32, + ctx.primitives.uint64, + ], + ) => + { + // Handle signed/unsigned ints + let n = scalar.value.into_int_value(); + ctx.builder + .build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME) + .unwrap() + } + () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { + // Handle floats + let f64_type = ctx.ctx.f64_type(); + let n = scalar.value.into_float_value(); + ctx.builder + .build_float_compare(FloatPredicate::UNE, n, f64_type.const_zero(), FN_NAME) + .unwrap() + } + () => unsupported_type(ctx, FN_NAME, &[n_ty]), + } + .as_basic_value_enum()) + }) + .map(ScalarOrNDArray::to_basic_value_enum) } -/// Invokes the `floor` builtin function. -pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( +/// Invokes the `round` builtin function. 
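+///
+/// Floats are rounded with `llvm_intrinsics::call_float_round` and the result is converted to
+/// the LLVM integer type of `ret_int_ty`; ndarray inputs are mapped element-wise.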
+pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), - ret_elem_ty: Type, + (n_ty, n): (Type, BasicValueEnum<'ctx>), + ret_int_ty: Type, ) -> Result, String> { - const FN_NAME: &str = "floor"; + const FN_NAME: &str = "round"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let ret_int_ty_llvm = ctx.get_llvm_abi_type(generator, ret_int_ty).into_int_type(); - let (n_ty, n) = n; - let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); - - Ok(match n { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - let val = llvm_intrinsics::call_float_floor(ctx, n, None); - if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { - ctx.builder - .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) - .map(Into::into) - .unwrap() - } else { - val.into() + split_scalar_or_ndarray(generator, ctx, n, n_ty) + .map(generator, ctx, ret_int_ty, |_generator, ctx, _i, scalar| { + Ok(match () { + () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { + let n = scalar.value.into_float_value(); + let n = llvm_intrinsics::call_float_round(ctx, n, None); + ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, FN_NAME).unwrap() + } + () => unsupported_type(ctx, FN_NAME, &[n_ty]), } - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) + .as_basic_value_enum()) + }) + .map(ScalarOrNDArray::to_basic_value_enum) } -/// Invokes the `ceil` builtin function. -pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( +/// Invokes the `np_round` builtin function. 
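+///
+/// Unlike `round`, the result is kept as a float: each value is rounded with `call_float_rint`,
+/// and ndarray inputs are mapped element-wise.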
+pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), - ret_elem_ty: Type, + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "ceil"; + const FN_NAME: &str = "np_round"; - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; - let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); - - Ok(match n { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - - let val = llvm_intrinsics::call_float_ceil(ctx, n, None); - if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { - ctx.builder - .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) - .map(Into::into) - .unwrap() - } else { - val.into() - } - } - - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); - - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() - } - - _ => unsupported_type(ctx, FN_NAME, &[n_ty]), - }) + split_scalar_or_ndarray(generator, ctx, n, n_ty) + .map(generator, ctx, ctx.primitives.float, |_generator, ctx, _i, scalar| { + Ok(match () { + () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { + let n = scalar.value.into_float_value(); + call_float_rint(ctx, n, None).as_basic_value_enum() + } + () => unsupported_type(ctx, FN_NAME, &[n_ty]), + }) + }) + .map(ScalarOrNDArray::to_basic_value_enum) } -/// Invokes the `min` builtin function. -pub fn call_min<'ctx>( +#[derive(Debug, Clone, Copy)] +pub enum CeilOrFloor { + Ceil, + Floor, +} + +/// Invokes the `ceil`/`floor` builtin function. 
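+///
+/// `kind` selects between `llvm_intrinsics::call_float_ceil` and `llvm_intrinsics::call_float_floor`.
+/// If `ret_elem_ty` is `float`, the rounded value is returned as-is; otherwise it is converted to
+/// the LLVM integer type of `ret_elem_ty`.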
+pub fn call_ceil_or_floor<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - m: (Type, BasicValueEnum<'ctx>), - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), + kind: CeilOrFloor, + ret_elem_ty: Type, // Can be float or int/bool +) -> Result, String> { + // TODO: Bad Rust tuple type inference leads to this + let fn_name = match kind { + CeilOrFloor::Ceil => "ceil", + CeilOrFloor::Floor => "floor", + }; + let function = match kind { + CeilOrFloor::Ceil => llvm_intrinsics::call_float_ceil, + CeilOrFloor::Floor => llvm_intrinsics::call_float_floor, + }; + + split_scalar_or_ndarray(generator, ctx, n, n_ty) + .map(generator, ctx, ctx.primitives.float, |generator, ctx, _i, scalar| { + Ok(match () { + () if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) => { + let n = scalar.value.into_float_value(); + let n = function(ctx, n, None); + + if ctx.unifier.unioned(ret_elem_ty, ctx.primitives.float) { + // If `ret_elem_ty` is a float, return the floored float directly + n.as_basic_value_enum() + } else { + // Otherwise ret_elem_ty must be IntType, cast the floored float to int + let ret_elem_ty_llvm = + ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type(); + ctx.builder + .build_float_to_signed_int(n, ret_elem_ty_llvm, fn_name) + .unwrap() + .as_basic_value_enum() + } + } + () => unsupported_type(ctx, fn_name, &[n_ty]), + }) + }) + .map(ScalarOrNDArray::to_basic_value_enum) +} + +#[derive(Debug, Clone, Copy)] +pub enum MinOrMax { + Min, + Max, +} + +/// Invokes the `min`/`max` builtin function. +pub fn call_min_or_max<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + (m_ty, m): (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), + kind: MinOrMax, ) -> BasicValueEnum<'ctx> { - const FN_NAME: &str = "min"; - - let (m_ty, m) = m; - let (n_ty, n) = n; + // TODO: Bad Rust type inference leads to this + let fn_name = match kind { + MinOrMax::Min => "min", + MinOrMax::Max => "max", + }; + let signedfn = match kind { + MinOrMax::Min => llvm_intrinsics::call_int_smin, + MinOrMax::Max => llvm_intrinsics::call_int_smax, + }; + let unsignedfn = match kind { + MinOrMax::Min => llvm_intrinsics::call_int_umin, + MinOrMax::Max => llvm_intrinsics::call_int_umax, + }; + let floatfn = match kind { + MinOrMax::Min => llvm_intrinsics::call_float_minnum, + MinOrMax::Max => llvm_intrinsics::call_float_maxnum, + }; let common_ty = if ctx.unifier.unioned(m_ty, n_ty) { m_ty } else { - unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]) + unsupported_type(ctx, fn_name, &[m_ty, n_ty]) }; - match (m, n) { - (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty, *ty))); - - if [ctx.primitives.int32, ctx.primitives.int64] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty, *ty)) - { - llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into() - } else { - llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into() - } + match () { + () if ctx.unifier.unioned_any( + common_ty, + [ctx.primitives.bool, ctx.primitives.uint32, ctx.primitives.uint64], + ) => + { + // Handle unsigned ints and booleans (treating booleans as unsigned ints) + unsignedfn(ctx, m.into_int_value(), n.into_int_value(), Some(fn_name)).into() } - - (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { - 
debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); - - llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into() + () if ctx.unifier.unioned_any(common_ty, [ctx.primitives.int32, ctx.primitives.int64]) => { + // Handle signed ints + signedfn(ctx, m.into_int_value(), n.into_int_value(), Some(fn_name)).into() } - - _ => unsupported_type(ctx, FN_NAME, &[m_ty, n_ty]), + () if ctx.unifier.unioned(common_ty, ctx.primitives.float) => { + // Handle floats + floatfn(ctx, m.into_float_value(), n.into_float_value(), Some(fn_name)).into() + } + () => unsupported_type(ctx, fn_name, &[m_ty, n_ty]), } } -/// Invokes the `np_minimum` builtin function. -pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( +/// Invokes the `np_minimum`/`np_maximum` builtin function. +pub fn call_numpy_minimum_or_maximum<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), + kind: MinOrMax, ) -> Result, String> { - const FN_NAME: &str = "np_minimum"; + let fn_name = match kind { + MinOrMax::Min => "np_minimum", + MinOrMax::Max => "np_maximum", + }; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + // starmap_scalars_array_like(generator, ctx, inputs, ret_dtype, mapping) - let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None }; + todo!() - Ok(match (x1, x2) { - (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!([ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.uint32, - ctx.primitives.int64, - ctx.primitives.uint64, - ctx.primitives.float, - ] - .iter() - .any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); + // let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None }; - call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) - } + // Ok(match (x1, x2) { + // (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { + // debug_assert!([ + // ctx.primitives.bool, + // ctx.primitives.int32, + // ctx.primitives.uint32, + // ctx.primitives.int64, + // ctx.primitives.uint64, + // ctx.primitives.float, + // ] + // .iter() + // .any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); + // call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) + // } - call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) - } + // (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + // debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + // call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) + // } - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + // (x1, x2) + // if [&x1_ty, &x2_ty].into_iter().any(|ty| { + // ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) + // }) => + // { + // let is_ndarray1 = + // 
x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + // let is_ndarray2 = + // x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + // let dtype = if is_ndarray1 && is_ndarray2 { + // let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + // let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; + // debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; + // ndarray_dtype1 + // } else if is_ndarray1 { + // unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + // } else if is_ndarray2 { + // unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + // } else { + // unreachable!() + // }; - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } + // let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; + // let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + // numpy::ndarray_elementwise_binop_impl( + // generator, + // ctx, + // dtype, + // None, + // (x1, !is_ndarray1), + // (x2, !is_ndarray2), + // |generator, ctx, (lhs, rhs)| { + // call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + // }, + // )? + // .as_base_value() + // .into() + // } + + // _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + // }) } /// Invokes the `max` builtin function. diff --git a/nac3core/src/codegen/model/util.rs b/nac3core/src/codegen/model/util.rs index 19f29762..f59f6950 100644 --- a/nac3core/src/codegen/model/util.rs +++ b/nac3core/src/codegen/model/util.rs @@ -27,14 +27,14 @@ pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ? /// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions. /// The [`IntKind`] is automatically inferred. 
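+/// The `body` callback returns a value of type `R`, which is passed through as this function's result.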
-pub fn gen_for_model_auto<'ctx, 'a, G, F, I>(
+pub fn gen_for_model_auto<'ctx, 'a, G, F, I, R>(
     generator: &mut G,
     ctx: &mut CodeGenContext<'ctx, 'a>,
     start: Int<'ctx, I>,
     stop: Int<'ctx, I>,
     step: Int<'ctx, I>,
     body: F,
-) -> Result<(), String>
+) -> Result<R, String>
 where
     G: CodeGenerator + ?Sized,
     F: FnOnce(
@@ -42,7 +42,7 @@ where
         &mut CodeGenContext<'ctx, 'a>,
         BreakContinueHooks<'ctx>,
         Int<'ctx, I>,
-    ) -> Result<(), String>,
+    ) -> Result<R, String>,
     I: IntKind<'ctx> + Default,
 {
     let int_model = IntModel(I::default());
diff --git a/nac3core/src/codegen/structure/ndarray/mapping.rs b/nac3core/src/codegen/structure/ndarray/mapping.rs
index 1c28b734..84d835cc 100644
--- a/nac3core/src/codegen/structure/ndarray/mapping.rs
+++ b/nac3core/src/codegen/structure/ndarray/mapping.rs
@@ -1,4 +1,3 @@
-use inkwell::values::BasicValueEnum;
 use itertools::Itertools;
 use util::gen_for_model_auto;
 
@@ -19,7 +18,6 @@ pub fn starmap_scalars_array_like<'ctx, 'a, F, G>(
     generator: &mut G,
     ctx: &mut CodeGenContext<'ctx, 'a>,
     inputs: &Vec<ScalarOrNDArray<'ctx>>,
-    ret_dtype: Type,
     mapping: F,
 ) -> Result<ScalarOrNDArray<'ctx>, String>
 where
@@ -28,7 +26,7 @@ where
         &mut CodeGenContext<'ctx, 'a>,
         Int<'ctx, SizeT>,
         &Vec<ScalarObject<'ctx>>,
-    ) -> Result<BasicValueEnum<'ctx>, String>,
+    ) -> Result<ScalarObject<'ctx>, String>,
     G: CodeGenerator + ?Sized,
 {
     assert!(!inputs.is_empty());
@@ -44,9 +42,9 @@ where
             // When inputs are all scalars, return a ScalarObject back
             let i = sizet_model.const_0(generator, ctx.ctx);
 
-            let ret = mapping(generator, ctx, i, &scalars)?;
-            Ok(ScalarOrNDArray::Scalar(ScalarObject { value: ret, dtype: ret_dtype }))
+            let scalar = mapping(generator, ctx, i, &scalars)?;
+            Ok(ScalarOrNDArray::Scalar(scalar))
         }
 
         None => {
             // When not all inputs are scalars, promote all non-ndarray inputs
@@ -57,22 +55,12 @@ where
 
             let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays);
 
-            let mapped_ndarray = NDArrayObject::alloca_uninitialized(
-                generator,
-                ctx,
-                ret_dtype,
-                broadcast_result.ndims,
-                "mapped_ndarray",
-            );
-            mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
-            mapped_ndarray.create_data(generator, ctx);
-
             let start = sizet_model.const_0(generator, ctx.ctx);
-            let stop = mapped_ndarray.size(generator, ctx);
+            let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
             let step = sizet_model.const_1(generator, ctx.ctx);
 
             // Map element-wise and store results into `mapped_ndarray`.
-            gen_for_model_auto(
+            let mapped_ndarray = gen_for_model_auto(
                 generator,
                 ctx,
                 start,
@@ -89,12 +77,26 @@ where
                        .collect_vec();
 
                    let ret = mapping(generator, ctx, i, &elements)?;
+
+                    // It might look odd to emit the allocation codegen here
+                    // inside the `for` loop, but it is intentional:
+                    // the `dtype` of the output is taken from `ret`, which is
+                    // only known after running `mapping`.
+ let mapped_ndarray = NDArrayObject::alloca_uninitialized( + generator, + ctx, + ret.dtype, + broadcast_result.ndims, + "mapped_ndarray", + ); + mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape); + mapped_ndarray.create_data(generator, ctx); + let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret"); - ctx.builder.build_store(pret, ret).unwrap(); - Ok(()) + ctx.builder.build_store(pret, ret.value).unwrap(); + Ok(mapped_ndarray) }, )?; - Ok(ScalarOrNDArray::NDArray(mapped_ndarray)) } } @@ -105,7 +107,6 @@ impl<'ctx> ScalarObject<'ctx> { &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, - ret_dtype: Type, mapping: F, ) -> Result where @@ -114,14 +115,13 @@ impl<'ctx> ScalarObject<'ctx> { &mut CodeGenContext<'ctx, 'a>, Int<'ctx, SizeT>, ScalarObject<'ctx>, - ) -> Result, String>, + ) -> Result, String>, G: CodeGenerator + ?Sized, { let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like( generator, ctx, &vec![ScalarOrNDArray::Scalar(*self)], - ret_dtype, |generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), )? else { @@ -136,7 +136,6 @@ impl<'ctx> NDArrayObject<'ctx> { &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, - ret_dtype: Type, mapping: F, ) -> Result where @@ -145,14 +144,13 @@ impl<'ctx> NDArrayObject<'ctx> { &mut CodeGenContext<'ctx, 'a>, Int<'ctx, SizeT>, ScalarObject<'ctx>, - ) -> Result, String>, + ) -> Result, String>, G: CodeGenerator + ?Sized, { let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like( generator, ctx, &vec![ScalarOrNDArray::NDArray(*self)], - ret_dtype, |generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), )? else { @@ -176,15 +174,18 @@ impl<'ctx> ScalarOrNDArray<'ctx> { &mut CodeGenContext<'ctx, 'a>, Int<'ctx, SizeT>, ScalarObject<'ctx>, - ) -> Result, String>, + ) -> Result, String>, G: CodeGenerator + ?Sized, { match self { - ScalarOrNDArray::Scalar(scalar) => { - scalar.map(generator, ctx, ret_dtype, mapping).map(ScalarOrNDArray::Scalar) - } + ScalarOrNDArray::Scalar(scalar) => starmap_scalars_array_like( + generator, + ctx, + &vec![ScalarOrNDArray::Scalar(*scalar)], + |generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), + ), ScalarOrNDArray::NDArray(ndarray) => { - ndarray.map(generator, ctx, ret_dtype, mapping).map(ScalarOrNDArray::NDArray) + ndarray.map(generator, ctx, mapping).map(ScalarOrNDArray::NDArray) } } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index c82b9983..1ed3628c 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1180,7 +1180,7 @@ impl<'a> BuiltinBuilder<'a> { let ret_elem_ty = size_variant.of_int(&ctx.primitives); let func = match kind { Kind::Ceil => builtin_fns::call_ceil, - Kind::Floor => builtin_fns::call_floor, + Kind::Floor => builtin_fns::call_ceil_or_floor, }; Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) }), @@ -1548,7 +1548,7 @@ impl<'a> BuiltinBuilder<'a> { let func = match prim { PrimDef::FunNpCeil => builtin_fns::call_ceil, - PrimDef::FunNpFloor => builtin_fns::call_floor, + PrimDef::FunNpFloor => builtin_fns::call_ceil_or_floor, _ => unreachable!(), }; Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 845e8406..032bd376 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -336,6 +336,14 @@ impl Unifier { self.unification_table.unioned(a, b) } + /// 
Determine if a type is unioned with any type in `tys`.
+    pub fn unioned_any<I>(&mut self, a: Type, tys: I) -> bool
+    where
+        I: IntoIterator<Item = Type>,
+    {
+        tys.into_iter().any(|ty| self.unioned(a, ty))
+    }
+
     pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
         let lock = unifier.lock().unwrap();
         Unifier {