From b1e97aa2b01bed6d8a6da03fa8e402a616d06f71 Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 20 Jun 2024 15:59:33 +0800 Subject: [PATCH] core: implement ndarray_iter_elem_impl and np.any() & np.all() --- nac3core/src/codegen/numpy.rs | 180 +++++++++++++++++++++++++- nac3core/src/codegen/stmt.rs | 2 +- nac3core/src/toplevel/builtins.rs | 20 +++ nac3core/src/toplevel/helper.rs | 4 + nac3standalone/demo/interpret_demo.py | 2 + nac3standalone/demo/src/ndarray.py | 45 +++++++ 6 files changed, 249 insertions(+), 4 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index b082d1ec..3da496d0 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -7,12 +7,11 @@ use crate::{ }, expr::gen_binop_expr_with_values, irrt::{ - calculate_len_for_slice_range, call_ndarray_calc_broadcast, + self, calculate_len_for_slice_range, call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, }, - llvm_intrinsics, - llvm_intrinsics::call_memcpy_generic, + llvm_intrinsics::{self, call_memcpy_generic}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, CodeGenContext, CodeGenerator, }, @@ -32,6 +31,8 @@ use inkwell::{ }; use nac3parser::ast::{Operator, StrRef}; +use super::{builtin_fns::call_bool, stmt::BreakContinueHooks}; + /// Creates an uninitialized `NDArray` instance. fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, @@ -1366,6 +1367,46 @@ where Ok(ndarray) } +/// LLVM-typed implementation for iterating through all elements within an `ndarray`. +/// +/// * `ndarray`: The input [`NDArrayValue`] to iterate through. +/// * `body`: A lambda containing IR statements that acts on every element within `ndarray`. +/// It may also implement short-circuiting logic by branching with [`BreakContinueHooks`]. +pub fn ndarray_iter_elem_impl<'ctx, G, BodyFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + body: BodyFn, +) -> Result<(), String> +where + G: CodeGenerator + ?Sized, + BodyFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, '_>, + BreakContinueHooks, + BasicValueEnum<'ctx>, + ) -> Result<(), String>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + + let ndarray_size = + irrt::call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None)); + + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_int(0, false), + (ndarray_size, false), + |generator, ctx, hooks, idx| { + let scalar = unsafe { ndarray.data().get_unchecked(ctx, generator, &idx, None) }; + body(generator, ctx, hooks, scalar) + }, + llvm_usize.const_int(1, false), + )?; + + Ok(()) +} + /// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -1984,3 +2025,136 @@ pub fn gen_ndarray_fill<'ctx>( Ok(()) } + +/// Used by [`call_ndarray_any_all_impl`] +#[derive(Debug, Clone, Copy)] +enum AnyOrAll { + /// The numpy function `np.any()` + IsAny, + /// The numpy function `np.all()` + IsAll, +} + +/// Helper function to create `np.any()` and `np.all()`. +/// +/// Returns a boolean result in the form of an `i8` [`IntValue`]. +/// +/// They are mixed together since they are extremely similar in terms of implementation. +fn call_ndarray_any_all_impl<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + kind: AnyOrAll, + elem_ty: Type, + ndarray: NDArrayValue<'ctx>, +) -> Result, String> { + /* + NOTE: `np.any()` returns false. + NOTE: `np.all()` returns true. + + Here is the reference C code of what the implemented LLVM of `np.any()` is essentially doing. + ```c + // np.any(ndarray) + const int8 neutral = 0; + const int8 on_hit = 1; + + int8 has_true = neutral; + for (size_t index = 0; index < ndarray.size; index++) { + Element *elem = ndarray.get(index); + bool elem_is_truthy = call_bool(*elem); + if (elem_is_truthy) { + has_true = on_hit; + break; // Short-circuiting for performance + } + } + ``` + */ + + // Name of the function. Used here for creating LLVM labels and names. + let fn_name = match kind { + AnyOrAll::IsAny => "np_any", + AnyOrAll::IsAll => "np_all", + }; + + let llvm_i8 = ctx.ctx.i8_type(); + + let (neutral, on_hit) = match kind { + AnyOrAll::IsAny => (0, 1), + AnyOrAll::IsAll => (1, 0), + }; + + // The result of `np.any()`/`np.all()` + let result_ptr = + ctx.builder.build_alloca(llvm_i8, format!("{fn_name}.result").as_str()).unwrap(); + ctx.builder.build_store(result_ptr, llvm_i8.const_int(neutral, false)).unwrap(); + + ndarray_iter_elem_impl(generator, ctx, ndarray, |generator, ctx, hooks, elem| { + // The basic block to go to when... + // - np.any() sees a `true`, then `result` is set from `false` to `true` and short-circuit. + // - np.all() sees a `false`, then `result` is set from `true` to `false` and short-circuit. + let on_hit_bb = + ctx.ctx.insert_basic_block_after(hooks.break_bb, format!("{fn_name}.on_hit").as_str()); + + let elem_is_truthy = call_bool(generator, ctx, (elem_ty, elem)).unwrap().into_int_value(); + let (on_true, on_false) = match kind { + AnyOrAll::IsAny => (on_hit_bb, hooks.continue_bb), + AnyOrAll::IsAll => (hooks.continue_bb, on_hit_bb), + }; + ctx.builder.build_conditional_branch(elem_is_truthy, on_true, on_false).unwrap(); + + // Begin inserting into `on_hit_bb` + ctx.builder.position_at_end(on_hit_bb); + ctx.builder.build_store(result_ptr, llvm_i8.const_int(on_hit, false)).unwrap(); + ctx.builder.build_unconditional_branch(hooks.break_bb).unwrap(); + + Ok(()) + })?; + + // Load `result` and return it + let result = ctx.builder.build_load(result_ptr, "result").unwrap(); + Ok(result.into_int_value()) +} + +/// Helper function to generate LLVM IR for `np.any()` and `np.all()`. +fn gen_ndarray_any_all_helper<'ctx>( + kind: AnyOrAll, + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let llvm_usize = generator.get_size_type(context.ctx); + + let in_ty = fun.0.args[0].ty; + let (elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, in_ty); + let in_arg = args[0].1.clone().to_basic_value_enum(context, generator, elem_ty)?; + + let ndarray = NDArrayValue::from_ptr_val(in_arg.into_pointer_value(), llvm_usize, None); + + call_ndarray_any_all_impl(generator, context, kind, elem_ty, ndarray) +} + +/// Generates LLVM IR for `np.any()`. +pub fn gen_ndarray_any<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + gen_ndarray_any_all_helper(AnyOrAll::IsAny, context, obj, fun, args, generator) +} + +/// Generates LLVM IR for `np.all()`. +pub fn gen_ndarray_all<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + gen_ndarray_any_all_helper(AnyOrAll::IsAll, context, obj, fun, args, generator) +} diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 8ffc3621..f6d865a3 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -462,7 +462,7 @@ pub fn gen_for( Ok(()) } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct BreakContinueHooks<'ctx> { /// [`BasicBlock`] to branch to for `break`-ing out of the loop. pub break_bb: BasicBlock<'ctx>, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index ef6618bc..99b30c10 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -497,6 +497,8 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim), + PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim), + PrimDef::FunNpSin | PrimDef::FunNpCos | PrimDef::FunNpTan @@ -1757,6 +1759,24 @@ impl<'a> BuiltinBuilder<'a> { } } + fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]); + let param_ty = &[(self.ndarray_num_ty, "a")]; + let ret_ty = self.primitives.bool; + let var_map = &into_var_map([]); + let codegen_callback: Box = + Box::new(move |ctx, obj, fun, args, generator| { + let func = match prim { + PrimDef::FunNpAny => gen_ndarray_any, + PrimDef::FunNpAll => gen_ndarray_all, + _ => unreachable!(), + }; + func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum())) + }); + + create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback) + } + fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { (prim.simple_name().into(), method_ty, prim.id()) } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 2fcc24c4..492d7ef5 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -100,6 +100,8 @@ pub enum PrimDef { FunNpHypot, FunNpNextAfter, FunSome, + FunNpAny, + FunNpAll, } /// Associated details of a [`PrimDef`] @@ -251,6 +253,8 @@ impl PrimDef { PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunSome => fun("Some", None), + PrimDef::FunNpAny => fun("np_any", None), + PrimDef::FunNpAll => fun("np_all", None), } } } diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 1b68bea6..b28322e9 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -226,6 +226,8 @@ def patch(module): module.np_full = np.full module.np_eye = np.eye module.np_identity = np.identity + module.np_any = np.any + module.np_all = np.all def file_import(filename, prefix="file_import_"): filename = pathlib.Path(filename) diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index eb124351..ea92c89c 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1388,6 +1388,48 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_ones) +def test_ndarray_any(): + x1 = np_identity(5) + y1 = np_any(x1) + output_ndarray_float_2(x1) + output_bool(y1) + + x2 = np_identity(1) + y2 = np_any(x2) + output_ndarray_float_2(x2) + output_bool(y2) + + x3 = np_array([[1.0, 2.0], [3.0, 4.0]]) + y3 = np_any(x3) + output_ndarray_float_2(x3) + output_bool(y3) + + x4 = np_zeros([3, 5]) + y4 = np_any(x4) + output_ndarray_float_2(x4) + output_bool(y4) + +def test_ndarray_all(): + x1 = np_identity(5) + y1 = np_all(x1) + output_ndarray_float_2(x1) + output_bool(y1) + + x2 = np_identity(1) + y2 = np_all(x2) + output_ndarray_float_2(x2) + output_bool(y2) + + x3 = np_array([[1.0, 2.0], [3.0, 4.0]]) + y3 = np_all(x3) + output_ndarray_float_2(x3) + output_bool(y3) + + x4 = np_zeros([3, 5]) + y4 = np_all(x4) + output_ndarray_float_2(x4) + output_bool(y4) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -1565,4 +1607,7 @@ def run() -> int32: test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() + test_ndarray_any() + test_ndarray_all() + return 0