From 37a29162c66dd278191ff7ce491f3e18946cb686 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 26 Apr 2024 19:31:15 +0800 Subject: [PATCH] core: WIP - round works now --- nac3core/src/codegen/builtin_fns.rs | 90 ++++++++++++++++----- nac3core/src/toplevel/builtins.rs | 110 ++++++++++++++++++-------- nac3standalone/demo/interpret_demo.py | 9 ++- nac3standalone/demo/src/ndarray.py | 13 +++ 4 files changed, 165 insertions(+), 57 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index fa6c865..58b37ad 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,5 +1,5 @@ use inkwell::{FloatPredicate, IntPredicate}; -use inkwell::types::{BasicTypeEnum, IntType}; +use inkwell::types::BasicTypeEnum; use inkwell::values::{BasicValueEnum, FloatValue, IntValue}; use itertools::Itertools; @@ -406,37 +406,89 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( } /// Invokes the `round` builtin function. -pub fn call_round<'ctx>( +pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, FloatValue<'ctx>), - llvm_ret_ty: IntType<'ctx>, -) -> IntValue<'ctx> { + n: (Type, BasicValueEnum<'ctx>), + ret_ty: Type, +) -> Result, String> { const FN_NAME: &str = "round"; + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; + let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty).into_int_type(); - if !ctx.unifier.unioned(n_ty, ctx.primitives.float) { - unsupported_type(ctx, FN_NAME, &[n_ty]) - } + Ok(match n.get_type() { + BasicTypeEnum::FloatType(_) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - let val = llvm_intrinsics::call_float_round(ctx, n, None); - ctx.builder - .build_float_to_signed_int(val, llvm_ret_ty, FN_NAME) - .unwrap() + let val = llvm_intrinsics::call_float_round(ctx, n.into_float_value(), None); + ctx.builder + .build_float_to_signed_int(val, llvm_ret_ty, FN_NAME) + .map(Into::into) + .unwrap() + } + + BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ret_ty, + None, + NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_round(generator, ctx, (elem_ty, val), ret_ty) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + }) } /// Invokes the `np_round` builtin function. -pub fn call_numpy_round<'ctx>( +pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, FloatValue<'ctx>), -) -> FloatValue<'ctx> { + n: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_round"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; - if !ctx.unifier.unioned(n_ty, ctx.primitives.float) { - unsupported_type(ctx, "np_round", &[n_ty]) - } + Ok(match n.get_type() { + BasicTypeEnum::FloatType(_) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_roundeven(ctx, n, None) + llvm_intrinsics::call_float_roundeven(ctx, n.into_float_value(), None).into() + } + + BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ctx.primitives.float, + None, + NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None), + |generator, ctx, val| { + call_numpy_round(generator, ctx, (elem_ty, val)) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + }) } /// Invokes the `bool` builtin function. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 343b158..4ea6da3 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -303,6 +303,11 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built None, ); let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None); + let float_or_ndarray_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); let num_or_ndarray_ty = unifier.get_fresh_var_with_range( &[num_ty.0, ndarray_num_ty], Some("T".into()), @@ -789,53 +794,88 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built .map(|val| Some(val.as_basic_value_enum())) }), ), - create_fn_by_codegen( - unifier, - &var_map, - "round", - int32, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i32 = ctx.ctx.i32_type(); + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int32, ndarray_int32], + Some("R".into()), + None, + ); - Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i32).into())) - }), - ), - create_fn_by_codegen( - unifier, - &var_map, - "round64", - int64, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i64 = ctx.ctx.i64_type(); + create_fn_by_codegen( + unifier, + &var_map, + "round", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) + }), + ) + }, + { + let common_ndim = unifier.get_fresh_const_generic_var( + primitives.usize(), + Some("N".into()), + None, + ); + let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i64).into())) - }), - ), + let p0_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float], + Some("T".into()), + None, + ); + let ret_ty = unifier.get_fresh_var_with_range( + &[int64, ndarray_int64], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + unifier, + &var_map, + "round64", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) + }), + ) + }, create_fn_by_codegen( unifier, &var_map, "np_round", - float, - &[(float, "n")], + float_or_ndarray_ty.0, + &[(float_or_ndarray_ty.0, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; - Ok(Some(builtin_fns::call_numpy_round(ctx, (arg_ty, arg)).into())) + Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?)) }), ), Arc::new(RwLock::new(TopLevelDef::Function { diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index b3c9f69..9511329 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -71,10 +71,13 @@ def _float(x): return float(x) def round_away_zero(x): - if x >= 0.0: - return math.floor(x + 0.5) + if isinstance(x, np.ndarray): + return np.vectorize(round_away_zero)(x) else: - return math.ceil(x - 0.5) + if x >= 0.0: + return math.floor(x + 0.5) + else: + return math.ceil(x - 0.5) def patch(module): def dbl_nan(): diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 4c3e7d0..f81fe0c 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -718,6 +718,17 @@ def test_ndarray_bool(): output_ndarray_float_2(x) output_ndarray_bool_2(y) +def test_ndarray_round(): + x = np_identity(2) + xf32 = round(x) + xf64 = round64(x) + xff = np_round(x) + + output_ndarray_float_2(x) + output_ndarray_int32_2(xf32) + output_ndarray_int64_2(xf64) + output_ndarray_float_2(xff) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -815,4 +826,6 @@ def run() -> int32: test_ndarray_float() test_ndarray_bool() + test_ndarray_round() + return 0