core: implement np.any() & np.all()
This commit is contained in:
parent
1bc95a7ba6
commit
554c9234b8
|
@ -2771,3 +2771,182 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
|
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if a [`Type`] is an ndarray.
|
||||||
|
fn is_type_ndarray(ctx: &mut CodeGenContext<'_, '_>, ty: Type) -> bool {
|
||||||
|
let ndarray = ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap();
|
||||||
|
ty.obj_id(&ctx.unifier) == Some(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
enum AnyOrAll {
|
||||||
|
IsAny,
|
||||||
|
IsAll,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to create `np.any()` and `np.all()`.
|
||||||
|
///
|
||||||
|
/// They are mixed together since they are extremely similar in terms of implementation.
|
||||||
|
fn helper_call_numpy_any_all<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
kind: AnyOrAll,
|
||||||
|
(in_ty, in_val): (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let fn_name = match kind {
|
||||||
|
AnyOrAll::IsAny => "np_any",
|
||||||
|
AnyOrAll::IsAll => "np_all",
|
||||||
|
};
|
||||||
|
|
||||||
|
// `neutral` is the value to return when `np.any()`/`np.all()` fails
|
||||||
|
// `on_hit` is the value to return when `np.any()`/`np.all()` succeeds
|
||||||
|
let (neutral, on_hit) = match kind {
|
||||||
|
AnyOrAll::IsAny => (0, 1),
|
||||||
|
AnyOrAll::IsAll => (1, 0),
|
||||||
|
};
|
||||||
|
|
||||||
|
match in_val {
|
||||||
|
BasicValueEnum::PointerValue(ndarray_ptr) if is_type_ndarray(ctx, in_ty) => {
|
||||||
|
/*
|
||||||
|
NOTE: `np.any(<empty ndarray>)` returns true.
|
||||||
|
NOTE: `np.all(<empty ndarray>)` returns false.
|
||||||
|
|
||||||
|
Here is the reference C code of what the implemented LLVM of `np.any()` is essentially doing.
|
||||||
|
```c
|
||||||
|
// np.any(ndarray)
|
||||||
|
int8 has_true = 0;
|
||||||
|
for (size_t index = 0; index < ndarray.size; index++) {
|
||||||
|
Element *elem = ndarray.get(index);
|
||||||
|
bool elem_is_truthy = call_bool(*elem);
|
||||||
|
if (elem_is_truthy) {
|
||||||
|
has_true = 1;
|
||||||
|
break; // Short-circuiting for performance
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, in_ty);
|
||||||
|
|
||||||
|
let in_ndarray = NDArrayValue::from_ptr_val(ndarray_ptr, llvm_usize, None);
|
||||||
|
let in_ndarray_size =
|
||||||
|
irrt::call_ndarray_calc_size(generator, ctx, &in_ndarray.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||||
|
let init_bb =
|
||||||
|
ctx.ctx.insert_basic_block_after(current_bb, format!("{fn_name}.init").as_str());
|
||||||
|
let step_bb =
|
||||||
|
ctx.ctx.insert_basic_block_after(init_bb, format!("{fn_name}.step").as_str());
|
||||||
|
let check_bb =
|
||||||
|
ctx.ctx.insert_basic_block_after(step_bb, format!("{fn_name}.check").as_str());
|
||||||
|
let on_hit_bb =
|
||||||
|
ctx.ctx.insert_basic_block_after(check_bb, format!("{fn_name}.on_hit").as_str());
|
||||||
|
let cont_bb =
|
||||||
|
ctx.ctx.insert_basic_block_after(on_hit_bb, format!("{fn_name}.end").as_str());
|
||||||
|
|
||||||
|
// ##### Inserting into `current_bb` #####
|
||||||
|
ctx.builder.build_unconditional_branch(init_bb).unwrap();
|
||||||
|
|
||||||
|
// ##### Inserting into `init_bb` #####
|
||||||
|
ctx.builder.position_at_end(init_bb);
|
||||||
|
|
||||||
|
// The (boolean) result of of `np.any()` or `np.all()`, as an i8. Defaults to neutral
|
||||||
|
let result_ptr = generator.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_i8.into(),
|
||||||
|
Some(format!("{fn_name}.result").as_str()),
|
||||||
|
)?;
|
||||||
|
ctx.builder.build_store(result_ptr, llvm_i8.const_int(neutral, false)).unwrap();
|
||||||
|
|
||||||
|
// Index to iterate through the input ndarray in `step_bb`
|
||||||
|
let index_ptr = generator.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_usize.into(),
|
||||||
|
Some(format!("{fn_name}.index").as_str()),
|
||||||
|
)?;
|
||||||
|
ctx.builder.build_store(index_ptr, llvm_usize.const_zero()).unwrap();
|
||||||
|
|
||||||
|
// Immediately begin iterating through the ndarray
|
||||||
|
ctx.builder.build_unconditional_branch(step_bb).unwrap();
|
||||||
|
|
||||||
|
// ##### Inserting into `step_bb` #####
|
||||||
|
ctx.builder.position_at_end(step_bb);
|
||||||
|
let index = ctx
|
||||||
|
.builder
|
||||||
|
.build_load(index_ptr, format!("{fn_name}.i").as_str())
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
let cond = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::ULT,
|
||||||
|
index,
|
||||||
|
in_ndarray_size,
|
||||||
|
format!("{fn_name}.cond").as_str(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Increment index
|
||||||
|
let index_next = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_add(
|
||||||
|
index,
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
format!("{fn_name}.i_next").as_str(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
ctx.builder.build_store(index_ptr, index_next).unwrap();
|
||||||
|
|
||||||
|
ctx.builder.build_conditional_branch(cond, check_bb, cont_bb).unwrap();
|
||||||
|
|
||||||
|
// ##### Inserting into `check_bb` #####
|
||||||
|
ctx.builder.position_at_end(check_bb);
|
||||||
|
let elem = unsafe { in_ndarray.data().get_unchecked(ctx, generator, &index, None) };
|
||||||
|
let elem_is_truthy = call_bool(generator, ctx, (elem_ty, elem))?.into_int_value();
|
||||||
|
// Depending on the `kind` we are working with, the conditional branch's condition may be negated.
|
||||||
|
// In terms of implementation, instead of doing `if (!elem_is_truthy) a else b`, we do `if (elem_is_truthy) b else a`.
|
||||||
|
match kind {
|
||||||
|
AnyOrAll::IsAny => ctx
|
||||||
|
.builder
|
||||||
|
.build_conditional_branch(elem_is_truthy, on_hit_bb, step_bb)
|
||||||
|
.unwrap(),
|
||||||
|
AnyOrAll::IsAll => ctx
|
||||||
|
.builder
|
||||||
|
.build_conditional_branch(elem_is_truthy, step_bb, on_hit_bb)
|
||||||
|
.unwrap(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// ##### Inserting into `on_hit_bb` #####
|
||||||
|
ctx.builder.position_at_end(on_hit_bb);
|
||||||
|
// Short-circuit for performance
|
||||||
|
ctx.builder.build_store(result_ptr, llvm_i8.const_int(on_hit, false)).unwrap();
|
||||||
|
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
||||||
|
|
||||||
|
// Done, reposition the builder to `cont_bb` for continuation
|
||||||
|
ctx.builder.position_at_end(cont_bb);
|
||||||
|
let result =
|
||||||
|
ctx.builder.build_load(result_ptr, format!("{fn_name}.result").as_str()).unwrap();
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
_ => unsupported_type(ctx, fn_name, &[in_ty]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_numpy_any<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
arg: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
helper_call_numpy_any_all(generator, ctx, AnyOrAll::IsAny, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_numpy_all<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
arg: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
helper_call_numpy_any_all(generator, ctx, AnyOrAll::IsAll, arg)
|
||||||
|
}
|
||||||
|
|
|
@ -497,6 +497,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
|
|
||||||
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
|
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
|
||||||
|
|
||||||
|
PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpSin
|
PrimDef::FunNpSin
|
||||||
| PrimDef::FunNpCos
|
| PrimDef::FunNpCos
|
||||||
| PrimDef::FunNpTan
|
| PrimDef::FunNpTan
|
||||||
|
@ -1757,6 +1759,29 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]);
|
||||||
|
let param_ty = &[(self.ndarray_num_ty, "a")];
|
||||||
|
let ret_ty = self.primitives.bool;
|
||||||
|
// let var_map = &into_var_map([self.ndarray_dtype_tvar, self.ndarray_ndims_tvar]);
|
||||||
|
let var_map = &into_var_map([]);
|
||||||
|
let codegen_callback: Box<GenCallCallback> =
|
||||||
|
Box::new(move |ctx, _obj, (fun_signature, _fun_def_id), args, generator| {
|
||||||
|
let in_ty = fun_signature.args[0].ty;
|
||||||
|
let in_val = args[0].1.clone().to_basic_value_enum(ctx, generator, in_ty)?;
|
||||||
|
|
||||||
|
let func = match prim {
|
||||||
|
PrimDef::FunNpAny => builtin_fns::call_numpy_any,
|
||||||
|
PrimDef::FunNpAll => builtin_fns::call_numpy_all,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let ret = func(generator, ctx, (in_ty, in_val))?;
|
||||||
|
Ok(Some(ret))
|
||||||
|
});
|
||||||
|
|
||||||
|
create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback)
|
||||||
|
}
|
||||||
|
|
||||||
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
||||||
(prim.simple_name().into(), method_ty, prim.id())
|
(prim.simple_name().into(), method_ty, prim.id())
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,6 +100,8 @@ pub enum PrimDef {
|
||||||
FunNpHypot,
|
FunNpHypot,
|
||||||
FunNpNextAfter,
|
FunNpNextAfter,
|
||||||
FunSome,
|
FunSome,
|
||||||
|
FunNpAny,
|
||||||
|
FunNpAll,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Associated details of a [`PrimDef`]
|
/// Associated details of a [`PrimDef`]
|
||||||
|
@ -251,6 +253,8 @@ impl PrimDef {
|
||||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||||
PrimDef::FunSome => fun("Some", None),
|
PrimDef::FunSome => fun("Some", None),
|
||||||
|
PrimDef::FunNpAny => fun("np_any", None),
|
||||||
|
PrimDef::FunNpAll => fun("np_all", None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -226,6 +226,8 @@ def patch(module):
|
||||||
module.np_full = np.full
|
module.np_full = np.full
|
||||||
module.np_eye = np.eye
|
module.np_eye = np.eye
|
||||||
module.np_identity = np.identity
|
module.np_identity = np.identity
|
||||||
|
module.np_any = np.any
|
||||||
|
module.np_all = np.all
|
||||||
|
|
||||||
def file_import(filename, prefix="file_import_"):
|
def file_import(filename, prefix="file_import_"):
|
||||||
filename = pathlib.Path(filename)
|
filename = pathlib.Path(filename)
|
||||||
|
|
|
@ -1379,6 +1379,48 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
||||||
output_ndarray_float_2(nextafter_x_zeros)
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
output_ndarray_float_2(nextafter_x_ones)
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_any():
|
||||||
|
x1 = np_identity(5)
|
||||||
|
y1 = np_any(x1)
|
||||||
|
output_ndarray_float_2(x1)
|
||||||
|
output_bool(y1)
|
||||||
|
|
||||||
|
x2 = np_identity(1)
|
||||||
|
y2 = np_any(x2)
|
||||||
|
output_ndarray_float_2(x2)
|
||||||
|
output_bool(y2)
|
||||||
|
|
||||||
|
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
y3 = np_any(x3)
|
||||||
|
output_ndarray_float_2(x3)
|
||||||
|
output_bool(y3)
|
||||||
|
|
||||||
|
x4 = np_zeros([3, 5])
|
||||||
|
y4 = np_any(x4)
|
||||||
|
output_ndarray_float_2(x4)
|
||||||
|
output_bool(y4)
|
||||||
|
|
||||||
|
def test_ndarray_all():
|
||||||
|
x1 = np_identity(5)
|
||||||
|
y1 = np_all(x1)
|
||||||
|
output_ndarray_float_2(x1)
|
||||||
|
output_bool(y1)
|
||||||
|
|
||||||
|
x2 = np_identity(1)
|
||||||
|
y2 = np_all(x2)
|
||||||
|
output_ndarray_float_2(x2)
|
||||||
|
output_bool(y2)
|
||||||
|
|
||||||
|
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
y3 = np_all(x3)
|
||||||
|
output_ndarray_float_2(x3)
|
||||||
|
output_bool(y3)
|
||||||
|
|
||||||
|
x4 = np_zeros([3, 5])
|
||||||
|
y4 = np_all(x4)
|
||||||
|
output_ndarray_float_2(x4)
|
||||||
|
output_bool(y4)
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
test_ndarray_ctor()
|
test_ndarray_ctor()
|
||||||
test_ndarray_empty()
|
test_ndarray_empty()
|
||||||
|
@ -1555,4 +1597,7 @@ def run() -> int32:
|
||||||
test_ndarray_nextafter_broadcast_lhs_scalar()
|
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||||
test_ndarray_nextafter_broadcast_rhs_scalar()
|
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||||
|
|
||||||
|
test_ndarray_any()
|
||||||
|
test_ndarray_all()
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
Loading…
Reference in New Issue