From 12bdf6f77c915a1229c4d3de8a2ffdaeb2a2efce Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 26 Apr 2024 19:31:15 +0800 Subject: [PATCH] core: WIP - round works now --- nac3core/src/codegen/builtin_fns.rs | 59 +++++++++++++++---- nac3core/src/toplevel/builtins.rs | 49 ++++++++------- nac3core/src/typecheck/type_inferencer/mod.rs | 3 +- nac3standalone/demo/interpret_demo.py | 9 ++- nac3standalone/demo/src/ndarray.py | 9 +++ 5 files changed, 93 insertions(+), 36 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 6506ea8..06ab14a 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,5 +1,5 @@ use inkwell::{FloatPredicate, IntPredicate}; -use inkwell::types::{BasicTypeEnum, IntType}; +use inkwell::types::BasicTypeEnum; use inkwell::values::{BasicValueEnum, FloatValue, IntValue}; use itertools::Itertools; @@ -446,23 +446,58 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( } /// Invokes the `round` builtin function. -pub fn call_round<'ctx>( +pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, FloatValue<'ctx>), - llvm_ret_ty: IntType<'ctx>, -) -> IntValue<'ctx> { + n: (Type, BasicValueEnum<'ctx>), + ret_ty: Type, +) -> Result, String> { const FN_NAME: &str = "round"; + let llvm_usize = generator.get_size_type(ctx.ctx); + let (n_ty, n) = n; + let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty).into_int_type(); - if !ctx.unifier.unioned(n_ty, ctx.primitives.float) { - unsupported_type(ctx, FN_NAME, &[n_ty]) - } + Ok(match n.get_type() { + BasicTypeEnum::FloatType(_) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - let val = llvm_intrinsics::call_float_round(ctx, n, None); - ctx.builder - .build_float_to_signed_int(val, llvm_ret_ty, FN_NAME) - .unwrap() + let val = llvm_intrinsics::call_float_round(ctx, n.into_float_value(), None); + ctx.builder + .build_float_to_signed_int(val, llvm_ret_ty, FN_NAME) + .map(Into::into) + .unwrap() + } + + BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + + let ndarray = ndarray_elementwise_unaryop_impl( + generator, + ctx, + ret_ty, + None, + NDArrayValue::from_ptr_val( + n.into_pointer_value(), + llvm_usize, + None, + ), + |generator, ctx, val| { + call_round( + generator, + ctx, + (elem_ty, val), + ret_ty, + ) + }, + )?; + + ndarray.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[n_ty]) + }) } /// Invokes the `np_round` builtin function. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 00db8c1..28ba443 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -782,23 +782,35 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .map(|val| Some(val.as_basic_value_enum())) }), ), - create_fn_by_codegen( - primitives, - &var_map, - "round", - int32, - &[(float, "n")], - Box::new(|ctx, _, fun, args, generator| { - let llvm_i32 = ctx.ctx.i32_type(); + { + let common_ndim = primitives.1.get_fresh_const_generic_var( + primitives.0.usize(), + Some("N".into()), + None, + ); + let ndarray_int32 = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(int32), Some(common_ndim.0)); + let ndarray_float = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(float), Some(common_ndim.0)); - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + let p0_ty = primitives.1 + .get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let ret_ty = primitives.1 + .get_fresh_var_with_range(&[int32, ndarray_int32], Some("R".into()), None); - Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i32).into())) - }), - ), + create_fn_by_codegen( + primitives, + &var_map, + "round", + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) + }), + ) + }, create_fn_by_codegen( primitives, &var_map, @@ -806,14 +818,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { int64, &[(float, "n")], Box::new(|ctx, _, fun, args, generator| { - let llvm_i64 = ctx.ctx.i64_type(); - let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)? - .into_float_value(); + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; - Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i64).into())) + Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) }), ), create_fn_by_codegen( diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 7fa6c63..6302e2d 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -855,8 +855,9 @@ impl<'a> Inferencer<'a> { "int32", "float", "bool", + "round", ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { - let target_ty = if id == &"int32".into() { + let target_ty = if id == &"int32".into() || id == &"round".into() { self.primitives.int32 } else if id == &"float".into() { self.primitives.float diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index b3c9f69..9511329 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -71,10 +71,13 @@ def _float(x): return float(x) def round_away_zero(x): - if x >= 0.0: - return math.floor(x + 0.5) + if isinstance(x, np.ndarray): + return np.vectorize(round_away_zero)(x) else: - return math.ceil(x - 0.5) + if x >= 0.0: + return math.floor(x + 0.5) + else: + return math.ceil(x - 0.5) def patch(module): def dbl_nan(): diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 4c3e7d0..fc63cc3 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -718,6 +718,13 @@ def test_ndarray_bool(): output_ndarray_float_2(x) output_ndarray_bool_2(y) +def test_ndarray_round(): + x = np_identity(2) + y = round(x) + + output_ndarray_float_2(x) + output_ndarray_int32_2(y) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -815,4 +822,6 @@ def run() -> int32: test_ndarray_float() test_ndarray_bool() + test_ndarray_round() + return 0