forked from M-Labs/nac3
core: implement ndarray_iter_elem_impl and np.any() & np.all()
This commit is contained in:
parent
42c5f906fb
commit
b1e97aa2b0
@ -7,12 +7,11 @@ use crate::{
|
|||||||
},
|
},
|
||||||
expr::gen_binop_expr_with_values,
|
expr::gen_binop_expr_with_values,
|
||||||
irrt::{
|
irrt::{
|
||||||
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
|
self, calculate_len_for_slice_range, call_ndarray_calc_broadcast,
|
||||||
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
|
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
|
||||||
call_ndarray_calc_size,
|
call_ndarray_calc_size,
|
||||||
},
|
},
|
||||||
llvm_intrinsics,
|
llvm_intrinsics::{self, call_memcpy_generic},
|
||||||
llvm_intrinsics::call_memcpy_generic,
|
|
||||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
@ -32,6 +31,8 @@ use inkwell::{
|
|||||||
};
|
};
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
use nac3parser::ast::{Operator, StrRef};
|
||||||
|
|
||||||
|
use super::{builtin_fns::call_bool, stmt::BreakContinueHooks};
|
||||||
|
|
||||||
/// Creates an uninitialized `NDArray` instance.
|
/// Creates an uninitialized `NDArray` instance.
|
||||||
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
@ -1366,6 +1367,46 @@ where
|
|||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for iterating through all elements within an `ndarray`.
|
||||||
|
///
|
||||||
|
/// * `ndarray`: The input [`NDArrayValue`] to iterate through.
|
||||||
|
/// * `body`: A lambda containing IR statements that acts on every element within `ndarray`.
|
||||||
|
/// It may also implement short-circuiting logic by branching with [`BreakContinueHooks`].
|
||||||
|
pub fn ndarray_iter_elem_impl<'ctx, G, BodyFn>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
body: BodyFn,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
BodyFn: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, '_>,
|
||||||
|
BreakContinueHooks,
|
||||||
|
BasicValueEnum<'ctx>,
|
||||||
|
) -> Result<(), String>,
|
||||||
|
{
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let ndarray_size =
|
||||||
|
irrt::call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_int(0, false),
|
||||||
|
(ndarray_size, false),
|
||||||
|
|generator, ctx, hooks, idx| {
|
||||||
|
let scalar = unsafe { ndarray.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
body(generator, ctx, hooks, scalar)
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
@ -1984,3 +2025,136 @@ pub fn gen_ndarray_fill<'ctx>(
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Used by [`call_ndarray_any_all_impl`]
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
enum AnyOrAll {
|
||||||
|
/// The numpy function `np.any()`
|
||||||
|
IsAny,
|
||||||
|
/// The numpy function `np.all()`
|
||||||
|
IsAll,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to create `np.any()` and `np.all()`.
|
||||||
|
///
|
||||||
|
/// Returns a boolean result in the form of an `i8` [`IntValue`].
|
||||||
|
///
|
||||||
|
/// They are mixed together since they are extremely similar in terms of implementation.
|
||||||
|
fn call_ndarray_any_all_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
kind: AnyOrAll,
|
||||||
|
elem_ty: Type,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
) -> Result<IntValue<'ctx>, String> {
|
||||||
|
/*
|
||||||
|
NOTE: `np.any(<empty ndarray>)` returns false.
|
||||||
|
NOTE: `np.all(<empty ndarray>)` returns true.
|
||||||
|
|
||||||
|
Here is the reference C code of what the implemented LLVM of `np.any()` is essentially doing.
|
||||||
|
```c
|
||||||
|
// np.any(ndarray)
|
||||||
|
const int8 neutral = 0;
|
||||||
|
const int8 on_hit = 1;
|
||||||
|
|
||||||
|
int8 has_true = neutral;
|
||||||
|
for (size_t index = 0; index < ndarray.size; index++) {
|
||||||
|
Element *elem = ndarray.get(index);
|
||||||
|
bool elem_is_truthy = call_bool(*elem);
|
||||||
|
if (elem_is_truthy) {
|
||||||
|
has_true = on_hit;
|
||||||
|
break; // Short-circuiting for performance
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Name of the function. Used here for creating LLVM labels and names.
|
||||||
|
let fn_name = match kind {
|
||||||
|
AnyOrAll::IsAny => "np_any",
|
||||||
|
AnyOrAll::IsAll => "np_all",
|
||||||
|
};
|
||||||
|
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
|
||||||
|
let (neutral, on_hit) = match kind {
|
||||||
|
AnyOrAll::IsAny => (0, 1),
|
||||||
|
AnyOrAll::IsAll => (1, 0),
|
||||||
|
};
|
||||||
|
|
||||||
|
// The result of `np.any()`/`np.all()`
|
||||||
|
let result_ptr =
|
||||||
|
ctx.builder.build_alloca(llvm_i8, format!("{fn_name}.result").as_str()).unwrap();
|
||||||
|
ctx.builder.build_store(result_ptr, llvm_i8.const_int(neutral, false)).unwrap();
|
||||||
|
|
||||||
|
ndarray_iter_elem_impl(generator, ctx, ndarray, |generator, ctx, hooks, elem| {
|
||||||
|
// The basic block to go to when...
|
||||||
|
// - np.any() sees a `true`, then `result` is set from `false` to `true` and short-circuit.
|
||||||
|
// - np.all() sees a `false`, then `result` is set from `true` to `false` and short-circuit.
|
||||||
|
let on_hit_bb =
|
||||||
|
ctx.ctx.insert_basic_block_after(hooks.break_bb, format!("{fn_name}.on_hit").as_str());
|
||||||
|
|
||||||
|
let elem_is_truthy = call_bool(generator, ctx, (elem_ty, elem)).unwrap().into_int_value();
|
||||||
|
let (on_true, on_false) = match kind {
|
||||||
|
AnyOrAll::IsAny => (on_hit_bb, hooks.continue_bb),
|
||||||
|
AnyOrAll::IsAll => (hooks.continue_bb, on_hit_bb),
|
||||||
|
};
|
||||||
|
ctx.builder.build_conditional_branch(elem_is_truthy, on_true, on_false).unwrap();
|
||||||
|
|
||||||
|
// Begin inserting into `on_hit_bb`
|
||||||
|
ctx.builder.position_at_end(on_hit_bb);
|
||||||
|
ctx.builder.build_store(result_ptr, llvm_i8.const_int(on_hit, false)).unwrap();
|
||||||
|
ctx.builder.build_unconditional_branch(hooks.break_bb).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Load `result` and return it
|
||||||
|
let result = ctx.builder.build_load(result_ptr, "result").unwrap();
|
||||||
|
Ok(result.into_int_value())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to generate LLVM IR for `np.any()` and `np.all()`.
|
||||||
|
fn gen_ndarray_any_all_helper<'ctx>(
|
||||||
|
kind: AnyOrAll,
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<IntValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(context.ctx);
|
||||||
|
|
||||||
|
let in_ty = fun.0.args[0].ty;
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, in_ty);
|
||||||
|
let in_arg = args[0].1.clone().to_basic_value_enum(context, generator, elem_ty)?;
|
||||||
|
|
||||||
|
let ndarray = NDArrayValue::from_ptr_val(in_arg.into_pointer_value(), llvm_usize, None);
|
||||||
|
|
||||||
|
call_ndarray_any_all_impl(generator, context, kind, elem_ty, ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.any()`.
|
||||||
|
pub fn gen_ndarray_any<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<IntValue<'ctx>, String> {
|
||||||
|
gen_ndarray_any_all_helper(AnyOrAll::IsAny, context, obj, fun, args, generator)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.all()`.
|
||||||
|
pub fn gen_ndarray_all<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<IntValue<'ctx>, String> {
|
||||||
|
gen_ndarray_any_all_helper(AnyOrAll::IsAll, context, obj, fun, args, generator)
|
||||||
|
}
|
||||||
|
@ -462,7 +462,7 @@ pub fn gen_for<G: CodeGenerator>(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub struct BreakContinueHooks<'ctx> {
|
pub struct BreakContinueHooks<'ctx> {
|
||||||
/// [`BasicBlock`] to branch to for `break`-ing out of the loop.
|
/// [`BasicBlock`] to branch to for `break`-ing out of the loop.
|
||||||
pub break_bb: BasicBlock<'ctx>,
|
pub break_bb: BasicBlock<'ctx>,
|
||||||
|
@ -497,6 +497,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
|
|
||||||
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
|
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
|
||||||
|
|
||||||
|
PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpSin
|
PrimDef::FunNpSin
|
||||||
| PrimDef::FunNpCos
|
| PrimDef::FunNpCos
|
||||||
| PrimDef::FunNpTan
|
| PrimDef::FunNpTan
|
||||||
@ -1757,6 +1759,24 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]);
|
||||||
|
let param_ty = &[(self.ndarray_num_ty, "a")];
|
||||||
|
let ret_ty = self.primitives.bool;
|
||||||
|
let var_map = &into_var_map([]);
|
||||||
|
let codegen_callback: Box<GenCallCallback> =
|
||||||
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
|
let func = match prim {
|
||||||
|
PrimDef::FunNpAny => gen_ndarray_any,
|
||||||
|
PrimDef::FunNpAll => gen_ndarray_all,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
});
|
||||||
|
|
||||||
|
create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback)
|
||||||
|
}
|
||||||
|
|
||||||
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
||||||
(prim.simple_name().into(), method_ty, prim.id())
|
(prim.simple_name().into(), method_ty, prim.id())
|
||||||
}
|
}
|
||||||
|
@ -100,6 +100,8 @@ pub enum PrimDef {
|
|||||||
FunNpHypot,
|
FunNpHypot,
|
||||||
FunNpNextAfter,
|
FunNpNextAfter,
|
||||||
FunSome,
|
FunSome,
|
||||||
|
FunNpAny,
|
||||||
|
FunNpAll,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Associated details of a [`PrimDef`]
|
/// Associated details of a [`PrimDef`]
|
||||||
@ -251,6 +253,8 @@ impl PrimDef {
|
|||||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||||
PrimDef::FunSome => fun("Some", None),
|
PrimDef::FunSome => fun("Some", None),
|
||||||
|
PrimDef::FunNpAny => fun("np_any", None),
|
||||||
|
PrimDef::FunNpAll => fun("np_all", None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -226,6 +226,8 @@ def patch(module):
|
|||||||
module.np_full = np.full
|
module.np_full = np.full
|
||||||
module.np_eye = np.eye
|
module.np_eye = np.eye
|
||||||
module.np_identity = np.identity
|
module.np_identity = np.identity
|
||||||
|
module.np_any = np.any
|
||||||
|
module.np_all = np.all
|
||||||
|
|
||||||
def file_import(filename, prefix="file_import_"):
|
def file_import(filename, prefix="file_import_"):
|
||||||
filename = pathlib.Path(filename)
|
filename = pathlib.Path(filename)
|
||||||
|
@ -1388,6 +1388,48 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
|||||||
output_ndarray_float_2(nextafter_x_zeros)
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
output_ndarray_float_2(nextafter_x_ones)
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_any():
|
||||||
|
x1 = np_identity(5)
|
||||||
|
y1 = np_any(x1)
|
||||||
|
output_ndarray_float_2(x1)
|
||||||
|
output_bool(y1)
|
||||||
|
|
||||||
|
x2 = np_identity(1)
|
||||||
|
y2 = np_any(x2)
|
||||||
|
output_ndarray_float_2(x2)
|
||||||
|
output_bool(y2)
|
||||||
|
|
||||||
|
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
y3 = np_any(x3)
|
||||||
|
output_ndarray_float_2(x3)
|
||||||
|
output_bool(y3)
|
||||||
|
|
||||||
|
x4 = np_zeros([3, 5])
|
||||||
|
y4 = np_any(x4)
|
||||||
|
output_ndarray_float_2(x4)
|
||||||
|
output_bool(y4)
|
||||||
|
|
||||||
|
def test_ndarray_all():
|
||||||
|
x1 = np_identity(5)
|
||||||
|
y1 = np_all(x1)
|
||||||
|
output_ndarray_float_2(x1)
|
||||||
|
output_bool(y1)
|
||||||
|
|
||||||
|
x2 = np_identity(1)
|
||||||
|
y2 = np_all(x2)
|
||||||
|
output_ndarray_float_2(x2)
|
||||||
|
output_bool(y2)
|
||||||
|
|
||||||
|
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
y3 = np_all(x3)
|
||||||
|
output_ndarray_float_2(x3)
|
||||||
|
output_bool(y3)
|
||||||
|
|
||||||
|
x4 = np_zeros([3, 5])
|
||||||
|
y4 = np_all(x4)
|
||||||
|
output_ndarray_float_2(x4)
|
||||||
|
output_bool(y4)
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
test_ndarray_ctor()
|
test_ndarray_ctor()
|
||||||
test_ndarray_empty()
|
test_ndarray_empty()
|
||||||
@ -1565,4 +1607,7 @@ def run() -> int32:
|
|||||||
test_ndarray_nextafter_broadcast_lhs_scalar()
|
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||||
test_ndarray_nextafter_broadcast_rhs_scalar()
|
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||||
|
|
||||||
|
test_ndarray_any()
|
||||||
|
test_ndarray_all()
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
Loading…
Reference in New Issue
Block a user