core: implement np.any() & np.all()
This commit is contained in:
parent
1bc95a7ba6
commit
554c9234b8
|
@ -2771,3 +2771,182 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
|||
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if a [`Type`] is an ndarray.
|
||||
fn is_type_ndarray(ctx: &mut CodeGenContext<'_, '_>, ty: Type) -> bool {
|
||||
let ndarray = ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap();
|
||||
ty.obj_id(&ctx.unifier) == Some(ndarray)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum AnyOrAll {
|
||||
IsAny,
|
||||
IsAll,
|
||||
}
|
||||
|
||||
/// Helper function to create `np.any()` and `np.all()`.
|
||||
///
|
||||
/// They are mixed together since they are extremely similar in terms of implementation.
|
||||
fn helper_call_numpy_any_all<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: AnyOrAll,
|
||||
(in_ty, in_val): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let fn_name = match kind {
|
||||
AnyOrAll::IsAny => "np_any",
|
||||
AnyOrAll::IsAll => "np_all",
|
||||
};
|
||||
|
||||
// `neutral` is the value to return when `np.any()`/`np.all()` fails
|
||||
// `on_hit` is the value to return when `np.any()`/`np.all()` succeeds
|
||||
let (neutral, on_hit) = match kind {
|
||||
AnyOrAll::IsAny => (0, 1),
|
||||
AnyOrAll::IsAll => (1, 0),
|
||||
};
|
||||
|
||||
match in_val {
|
||||
BasicValueEnum::PointerValue(ndarray_ptr) if is_type_ndarray(ctx, in_ty) => {
|
||||
/*
|
||||
NOTE: `np.any(<empty ndarray>)` returns true.
|
||||
NOTE: `np.all(<empty ndarray>)` returns false.
|
||||
|
||||
Here is the reference C code of what the implemented LLVM of `np.any()` is essentially doing.
|
||||
```c
|
||||
// np.any(ndarray)
|
||||
int8 has_true = 0;
|
||||
for (size_t index = 0; index < ndarray.size; index++) {
|
||||
Element *elem = ndarray.get(index);
|
||||
bool elem_is_truthy = call_bool(*elem);
|
||||
if (elem_is_truthy) {
|
||||
has_true = 1;
|
||||
break; // Short-circuiting for performance
|
||||
}
|
||||
}
|
||||
```
|
||||
*/
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, in_ty);
|
||||
|
||||
let in_ndarray = NDArrayValue::from_ptr_val(ndarray_ptr, llvm_usize, None);
|
||||
let in_ndarray_size =
|
||||
irrt::call_ndarray_calc_size(generator, ctx, &in_ndarray.dim_sizes(), (None, None));
|
||||
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let init_bb =
|
||||
ctx.ctx.insert_basic_block_after(current_bb, format!("{fn_name}.init").as_str());
|
||||
let step_bb =
|
||||
ctx.ctx.insert_basic_block_after(init_bb, format!("{fn_name}.step").as_str());
|
||||
let check_bb =
|
||||
ctx.ctx.insert_basic_block_after(step_bb, format!("{fn_name}.check").as_str());
|
||||
let on_hit_bb =
|
||||
ctx.ctx.insert_basic_block_after(check_bb, format!("{fn_name}.on_hit").as_str());
|
||||
let cont_bb =
|
||||
ctx.ctx.insert_basic_block_after(on_hit_bb, format!("{fn_name}.end").as_str());
|
||||
|
||||
// ##### Inserting into `current_bb` #####
|
||||
ctx.builder.build_unconditional_branch(init_bb).unwrap();
|
||||
|
||||
// ##### Inserting into `init_bb` #####
|
||||
ctx.builder.position_at_end(init_bb);
|
||||
|
||||
// The (boolean) result of of `np.any()` or `np.all()`, as an i8. Defaults to neutral
|
||||
let result_ptr = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_i8.into(),
|
||||
Some(format!("{fn_name}.result").as_str()),
|
||||
)?;
|
||||
ctx.builder.build_store(result_ptr, llvm_i8.const_int(neutral, false)).unwrap();
|
||||
|
||||
// Index to iterate through the input ndarray in `step_bb`
|
||||
let index_ptr = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_usize.into(),
|
||||
Some(format!("{fn_name}.index").as_str()),
|
||||
)?;
|
||||
ctx.builder.build_store(index_ptr, llvm_usize.const_zero()).unwrap();
|
||||
|
||||
// Immediately begin iterating through the ndarray
|
||||
ctx.builder.build_unconditional_branch(step_bb).unwrap();
|
||||
|
||||
// ##### Inserting into `step_bb` #####
|
||||
ctx.builder.position_at_end(step_bb);
|
||||
let index = ctx
|
||||
.builder
|
||||
.build_load(index_ptr, format!("{fn_name}.i").as_str())
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let cond = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::ULT,
|
||||
index,
|
||||
in_ndarray_size,
|
||||
format!("{fn_name}.cond").as_str(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Increment index
|
||||
let index_next = ctx
|
||||
.builder
|
||||
.build_int_add(
|
||||
index,
|
||||
llvm_usize.const_int(1, false),
|
||||
format!("{fn_name}.i_next").as_str(),
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(index_ptr, index_next).unwrap();
|
||||
|
||||
ctx.builder.build_conditional_branch(cond, check_bb, cont_bb).unwrap();
|
||||
|
||||
// ##### Inserting into `check_bb` #####
|
||||
ctx.builder.position_at_end(check_bb);
|
||||
let elem = unsafe { in_ndarray.data().get_unchecked(ctx, generator, &index, None) };
|
||||
let elem_is_truthy = call_bool(generator, ctx, (elem_ty, elem))?.into_int_value();
|
||||
// Depending on the `kind` we are working with, the conditional branch's condition may be negated.
|
||||
// In terms of implementation, instead of doing `if (!elem_is_truthy) a else b`, we do `if (elem_is_truthy) b else a`.
|
||||
match kind {
|
||||
AnyOrAll::IsAny => ctx
|
||||
.builder
|
||||
.build_conditional_branch(elem_is_truthy, on_hit_bb, step_bb)
|
||||
.unwrap(),
|
||||
AnyOrAll::IsAll => ctx
|
||||
.builder
|
||||
.build_conditional_branch(elem_is_truthy, step_bb, on_hit_bb)
|
||||
.unwrap(),
|
||||
};
|
||||
|
||||
// ##### Inserting into `on_hit_bb` #####
|
||||
ctx.builder.position_at_end(on_hit_bb);
|
||||
// Short-circuit for performance
|
||||
ctx.builder.build_store(result_ptr, llvm_i8.const_int(on_hit, false)).unwrap();
|
||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
||||
|
||||
// Done, reposition the builder to `cont_bb` for continuation
|
||||
ctx.builder.position_at_end(cont_bb);
|
||||
let result =
|
||||
ctx.builder.build_load(result_ptr, format!("{fn_name}.result").as_str()).unwrap();
|
||||
Ok(result)
|
||||
}
|
||||
_ => unsupported_type(ctx, fn_name, &[in_ty]),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_numpy_any<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
arg: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
helper_call_numpy_any_all(generator, ctx, AnyOrAll::IsAny, arg)
|
||||
}
|
||||
|
||||
pub fn call_numpy_all<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
arg: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
helper_call_numpy_any_all(generator, ctx, AnyOrAll::IsAll, arg)
|
||||
}
|
||||
|
|
|
@ -497,6 +497,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
|
||||
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
|
||||
|
||||
PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim),
|
||||
|
||||
PrimDef::FunNpSin
|
||||
| PrimDef::FunNpCos
|
||||
| PrimDef::FunNpTan
|
||||
|
@ -1757,6 +1759,29 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]);
|
||||
let param_ty = &[(self.ndarray_num_ty, "a")];
|
||||
let ret_ty = self.primitives.bool;
|
||||
// let var_map = &into_var_map([self.ndarray_dtype_tvar, self.ndarray_ndims_tvar]);
|
||||
let var_map = &into_var_map([]);
|
||||
let codegen_callback: Box<GenCallCallback> =
|
||||
Box::new(move |ctx, _obj, (fun_signature, _fun_def_id), args, generator| {
|
||||
let in_ty = fun_signature.args[0].ty;
|
||||
let in_val = args[0].1.clone().to_basic_value_enum(ctx, generator, in_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpAny => builtin_fns::call_numpy_any,
|
||||
PrimDef::FunNpAll => builtin_fns::call_numpy_all,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let ret = func(generator, ctx, (in_ty, in_val))?;
|
||||
Ok(Some(ret))
|
||||
});
|
||||
|
||||
create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback)
|
||||
}
|
||||
|
||||
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
||||
(prim.simple_name().into(), method_ty, prim.id())
|
||||
}
|
||||
|
|
|
@ -100,6 +100,8 @@ pub enum PrimDef {
|
|||
FunNpHypot,
|
||||
FunNpNextAfter,
|
||||
FunSome,
|
||||
FunNpAny,
|
||||
FunNpAll,
|
||||
}
|
||||
|
||||
/// Associated details of a [`PrimDef`]
|
||||
|
@ -251,6 +253,8 @@ impl PrimDef {
|
|||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||
PrimDef::FunSome => fun("Some", None),
|
||||
PrimDef::FunNpAny => fun("np_any", None),
|
||||
PrimDef::FunNpAll => fun("np_all", None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -226,6 +226,8 @@ def patch(module):
|
|||
module.np_full = np.full
|
||||
module.np_eye = np.eye
|
||||
module.np_identity = np.identity
|
||||
module.np_any = np.any
|
||||
module.np_all = np.all
|
||||
|
||||
def file_import(filename, prefix="file_import_"):
|
||||
filename = pathlib.Path(filename)
|
||||
|
|
|
@ -1379,6 +1379,48 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
|||
output_ndarray_float_2(nextafter_x_zeros)
|
||||
output_ndarray_float_2(nextafter_x_ones)
|
||||
|
||||
def test_ndarray_any():
|
||||
x1 = np_identity(5)
|
||||
y1 = np_any(x1)
|
||||
output_ndarray_float_2(x1)
|
||||
output_bool(y1)
|
||||
|
||||
x2 = np_identity(1)
|
||||
y2 = np_any(x2)
|
||||
output_ndarray_float_2(x2)
|
||||
output_bool(y2)
|
||||
|
||||
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||
y3 = np_any(x3)
|
||||
output_ndarray_float_2(x3)
|
||||
output_bool(y3)
|
||||
|
||||
x4 = np_zeros([3, 5])
|
||||
y4 = np_any(x4)
|
||||
output_ndarray_float_2(x4)
|
||||
output_bool(y4)
|
||||
|
||||
def test_ndarray_all():
|
||||
x1 = np_identity(5)
|
||||
y1 = np_all(x1)
|
||||
output_ndarray_float_2(x1)
|
||||
output_bool(y1)
|
||||
|
||||
x2 = np_identity(1)
|
||||
y2 = np_all(x2)
|
||||
output_ndarray_float_2(x2)
|
||||
output_bool(y2)
|
||||
|
||||
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||
y3 = np_all(x3)
|
||||
output_ndarray_float_2(x3)
|
||||
output_bool(y3)
|
||||
|
||||
x4 = np_zeros([3, 5])
|
||||
y4 = np_all(x4)
|
||||
output_ndarray_float_2(x4)
|
||||
output_bool(y4)
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
|
@ -1555,4 +1597,7 @@ def run() -> int32:
|
|||
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||
|
||||
test_ndarray_any()
|
||||
test_ndarray_all()
|
||||
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue