diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 58b37ad..4c26b43 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -10,6 +10,8 @@ use crate::toplevel::helper::PRIMITIVE_DEF_IDS; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::Type; +// TODO: Rename ret_ty to ret_elem_ty or similar + /// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// /// The generated message will contain the function name and the name of the unsupported type. @@ -567,55 +569,117 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( } /// Invokes the `floor` builtin function. -pub fn call_floor<'ctx>( +pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, FloatValue<'ctx>), - llvm_ret_ty: BasicTypeEnum<'ctx>, -) -> BasicValueEnum<'ctx> { + n: (Type, BasicValueEnum<'ctx>), + ret_ty: Type, +) -> Result, String> { const FN_NAME: &str = "floor"; + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let val = llvm_intrinsics::call_float_floor(ctx, n, None); - match llvm_ret_ty { - _ if llvm_ret_ty == val.get_type().into() => val.into(), + Ok(match n.get_type() { + BasicTypeEnum::IntType(_) + | BasicTypeEnum::FloatType(_) => { + debug_assert!([ + ctx.primitives.int32, + ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ctx.primitives.float, + ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); - BasicTypeEnum::IntType(_) => { - ctx.builder - .build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME) - .map(Into::into) - .unwrap() + let val = llvm_intrinsics::call_float_floor(ctx, n.into_float_value(), None); + if llvm_ret_ty.is_int_type() { + ctx.builder + .build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME) + .map(Into::into) + .unwrap() + } else { + val.into() + } + } + + BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ret_ty, + None, + NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_floor(generator, ctx, (elem_ty, val), ret_ty) + }, + )?; + + ndarray.as_ptr_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]) - } + }) } /// Invokes the `ceil` builtin function. -pub fn call_ceil<'ctx>( +pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, FloatValue<'ctx>), - llvm_ret_ty: BasicTypeEnum<'ctx>, -) -> BasicValueEnum<'ctx> { + n: (Type, BasicValueEnum<'ctx>), + ret_ty: Type, +) -> Result, String> { const FN_NAME: &str = "ceil"; + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); + let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let val = llvm_intrinsics::call_float_ceil(ctx, n, None); - match llvm_ret_ty { - _ if llvm_ret_ty == val.get_type().into() => val.into(), + Ok(match n.get_type() { + BasicTypeEnum::IntType(_) + | BasicTypeEnum::FloatType(_) => { + debug_assert!([ + ctx.primitives.int32, + ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ctx.primitives.float, + ].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty))); - BasicTypeEnum::IntType(_) => { - ctx.builder - .build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME) - .map(Into::into) - .unwrap() + let val = llvm_intrinsics::call_float_ceil(ctx, n.into_float_value(), None); + if llvm_ret_ty.is_int_type() { + ctx.builder + .build_float_to_signed_int(val, llvm_ret_ty.into_int_type(), FN_NAME) + .map(Into::into) + .unwrap() + } else { + val.into() + } + } + + BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ret_ty, + None, + NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_floor(generator, ctx, (elem_ty, val), ret_ty) + }, + )?; + + ndarray.as_ptr_value().into() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]) - } + }) } /// Invokes the `min` builtin function. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 4ea6da3..c64a196 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -363,6 +363,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap(); let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); + // TODO: Double check (T | ndarray[T]).to_basic_value_enum converts correctly + // TODO: Directly obtain Type instance after get_fresh_var_with_ let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( PRIMITIVE_DEF_IDS.int32, @@ -1025,102 +1027,172 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built )))), loc: None, })), - create_fn_by_codegen( - unifier, - &var_map, - "floor", - int32, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i32 = ctx.ctx.i32_type(); + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int32, ndarray_int32], + Some("R".into()), + None, + ); - Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), llvm_i32.into()))) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "floor64", - int64, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i64 = ctx.ctx.i64_type(); + create_fn_by_codegen( + unifier, + &var_map, + "floor", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) + }), + ) + }, + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), llvm_i64.into()))) - }), - ), + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int64, ndarray_int64], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + unifier, + &var_map, + "floor64", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) + }), + ) + }, create_fn_by_codegen( unifier, &var_map, "np_floor", - float, - &[(float, "n")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; - Ok(Some(builtin_fns::call_floor(ctx, (arg_ty, arg), arg.get_type().into()))) + Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) }), ), - create_fn_by_codegen( - unifier, - &var_map, - "ceil", - int32, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i32 = ctx.ctx.i32_type(); + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int32, ndarray_int32], + Some("R".into()), + None, + ); - Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), llvm_i32.into()))) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "ceil64", - int64, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i64 = ctx.ctx.i64_type(); + create_fn_by_codegen( + unifier, + &var_map, + "ceil", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) + }), + ) + }, + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), llvm_i64.into()))) - }), - ), + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int64, ndarray_int64], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + unifier, + &var_map, + "ceil64", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) + }), + ) + }, create_fn_by_codegen( unifier, &var_map, "np_ceil", - float, - &[(float, "n")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; - Ok(Some(builtin_fns::call_ceil(ctx, (arg_ty, arg), arg.get_type().into()))) + Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) }), ), Arc::new(RwLock::new({ diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 9511329..f38e70b 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -79,6 +79,18 @@ def round_away_zero(x): else: return math.ceil(x - 0.5) +def _floor(x): + if isinstance(x, np.ndarray): + return np.vectorize(_floor)(x) + else: + return math.floor(x) + +def _ceil(x): + if isinstance(x, np.ndarray): + return np.vectorize(_ceil)(x) + else: + return math.ceil(x) + def patch(module): def dbl_nan(): return np.nan @@ -142,11 +154,11 @@ def patch(module): module.round = round_away_zero module.round64 = round_away_zero module.np_round = np.round - module.floor = math.floor - module.floor64 = math.floor + module.floor = _floor + module.floor64 = _floor module.np_floor = np.floor - module.ceil = math.ceil - module.ceil64 = math.ceil + module.ceil = _ceil + module.ceil64 = _ceil module.np_ceil = np.ceil # NumPy ndarray functions diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index f81fe0c..9af48f1 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -729,6 +729,28 @@ def test_ndarray_round(): output_ndarray_int64_2(xf64) output_ndarray_float_2(xff) +def test_ndarray_floor(): + x = np_identity(2) + xf32 = floor(x) + xf64 = floor64(x) + xff = np_floor(x) + + output_ndarray_float_2(x) + output_ndarray_int32_2(xf32) + output_ndarray_int64_2(xf64) + output_ndarray_float_2(xff) + +def test_ndarray_ceil(): + x = np_identity(2) + xf32 = ceil(x) + xf64 = ceil64(x) + xff = np_ceil(x) + + output_ndarray_float_2(x) + output_ndarray_int32_2(xf32) + output_ndarray_int64_2(xf64) + output_ndarray_float_2(xff) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -827,5 +849,6 @@ def run() -> int32: test_ndarray_bool() test_ndarray_round() + test_ndarray_floor() return 0