core: implement np.any() & np.all()

This commit is contained in:
lyken 2024-06-19 17:28:47 +08:00
parent 1bc95a7ba6
commit 554c9234b8
5 changed files with 255 additions and 0 deletions

View File

@ -2771,3 +2771,182 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
})
}
/// Check if a [`Type`] is an ndarray.
fn is_type_ndarray(ctx: &mut CodeGenContext<'_, '_>, ty: Type) -> bool {
let ndarray = ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap();
ty.obj_id(&ctx.unifier) == Some(ndarray)
}
#[derive(Debug, Clone, Copy)]
enum AnyOrAll {
IsAny,
IsAll,
}
/// Helper function to create `np.any()` and `np.all()`.
///
/// They are mixed together since they are extremely similar in terms of implementation.
fn helper_call_numpy_any_all<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: AnyOrAll,
(in_ty, in_val): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let fn_name = match kind {
AnyOrAll::IsAny => "np_any",
AnyOrAll::IsAll => "np_all",
};
// `neutral` is the value to return when `np.any()`/`np.all()` fails
// `on_hit` is the value to return when `np.any()`/`np.all()` succeeds
let (neutral, on_hit) = match kind {
AnyOrAll::IsAny => (0, 1),
AnyOrAll::IsAll => (1, 0),
};
match in_val {
BasicValueEnum::PointerValue(ndarray_ptr) if is_type_ndarray(ctx, in_ty) => {
/*
NOTE: `np.any(<empty ndarray>)` returns true.
NOTE: `np.all(<empty ndarray>)` returns false.
Here is the reference C code of what the implemented LLVM of `np.any()` is essentially doing.
```c
// np.any(ndarray)
int8 has_true = 0;
for (size_t index = 0; index < ndarray.size; index++) {
Element *elem = ndarray.get(index);
bool elem_is_truthy = call_bool(*elem);
if (elem_is_truthy) {
has_true = 1;
break; // Short-circuiting for performance
}
}
```
*/
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, in_ty);
let in_ndarray = NDArrayValue::from_ptr_val(ndarray_ptr, llvm_usize, None);
let in_ndarray_size =
irrt::call_ndarray_calc_size(generator, ctx, &in_ndarray.dim_sizes(), (None, None));
let llvm_i8 = ctx.ctx.i8_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let current_bb = ctx.builder.get_insert_block().unwrap();
let init_bb =
ctx.ctx.insert_basic_block_after(current_bb, format!("{fn_name}.init").as_str());
let step_bb =
ctx.ctx.insert_basic_block_after(init_bb, format!("{fn_name}.step").as_str());
let check_bb =
ctx.ctx.insert_basic_block_after(step_bb, format!("{fn_name}.check").as_str());
let on_hit_bb =
ctx.ctx.insert_basic_block_after(check_bb, format!("{fn_name}.on_hit").as_str());
let cont_bb =
ctx.ctx.insert_basic_block_after(on_hit_bb, format!("{fn_name}.end").as_str());
// ##### Inserting into `current_bb` #####
ctx.builder.build_unconditional_branch(init_bb).unwrap();
// ##### Inserting into `init_bb` #####
ctx.builder.position_at_end(init_bb);
// The (boolean) result of of `np.any()` or `np.all()`, as an i8. Defaults to neutral
let result_ptr = generator.gen_var_alloc(
ctx,
llvm_i8.into(),
Some(format!("{fn_name}.result").as_str()),
)?;
ctx.builder.build_store(result_ptr, llvm_i8.const_int(neutral, false)).unwrap();
// Index to iterate through the input ndarray in `step_bb`
let index_ptr = generator.gen_var_alloc(
ctx,
llvm_usize.into(),
Some(format!("{fn_name}.index").as_str()),
)?;
ctx.builder.build_store(index_ptr, llvm_usize.const_zero()).unwrap();
// Immediately begin iterating through the ndarray
ctx.builder.build_unconditional_branch(step_bb).unwrap();
// ##### Inserting into `step_bb` #####
ctx.builder.position_at_end(step_bb);
let index = ctx
.builder
.build_load(index_ptr, format!("{fn_name}.i").as_str())
.unwrap()
.into_int_value();
let cond = ctx
.builder
.build_int_compare(
IntPredicate::ULT,
index,
in_ndarray_size,
format!("{fn_name}.cond").as_str(),
)
.unwrap();
// Increment index
let index_next = ctx
.builder
.build_int_add(
index,
llvm_usize.const_int(1, false),
format!("{fn_name}.i_next").as_str(),
)
.unwrap();
ctx.builder.build_store(index_ptr, index_next).unwrap();
ctx.builder.build_conditional_branch(cond, check_bb, cont_bb).unwrap();
// ##### Inserting into `check_bb` #####
ctx.builder.position_at_end(check_bb);
let elem = unsafe { in_ndarray.data().get_unchecked(ctx, generator, &index, None) };
let elem_is_truthy = call_bool(generator, ctx, (elem_ty, elem))?.into_int_value();
// Depending on the `kind` we are working with, the conditional branch's condition may be negated.
// In terms of implementation, instead of doing `if (!elem_is_truthy) a else b`, we do `if (elem_is_truthy) b else a`.
match kind {
AnyOrAll::IsAny => ctx
.builder
.build_conditional_branch(elem_is_truthy, on_hit_bb, step_bb)
.unwrap(),
AnyOrAll::IsAll => ctx
.builder
.build_conditional_branch(elem_is_truthy, step_bb, on_hit_bb)
.unwrap(),
};
// ##### Inserting into `on_hit_bb` #####
ctx.builder.position_at_end(on_hit_bb);
// Short-circuit for performance
ctx.builder.build_store(result_ptr, llvm_i8.const_int(on_hit, false)).unwrap();
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
// Done, reposition the builder to `cont_bb` for continuation
ctx.builder.position_at_end(cont_bb);
let result =
ctx.builder.build_load(result_ptr, format!("{fn_name}.result").as_str()).unwrap();
Ok(result)
}
_ => unsupported_type(ctx, fn_name, &[in_ty]),
}
}
pub fn call_numpy_any<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
arg: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
helper_call_numpy_any_all(generator, ctx, AnyOrAll::IsAny, arg)
}
pub fn call_numpy_all<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
arg: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
helper_call_numpy_any_all(generator, ctx, AnyOrAll::IsAll, arg)
}

View File

@ -497,6 +497,8 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim),
PrimDef::FunNpSin
| PrimDef::FunNpCos
| PrimDef::FunNpTan
@ -1757,6 +1759,29 @@ impl<'a> BuiltinBuilder<'a> {
}
}
fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]);
let param_ty = &[(self.ndarray_num_ty, "a")];
let ret_ty = self.primitives.bool;
// let var_map = &into_var_map([self.ndarray_dtype_tvar, self.ndarray_ndims_tvar]);
let var_map = &into_var_map([]);
let codegen_callback: Box<GenCallCallback> =
Box::new(move |ctx, _obj, (fun_signature, _fun_def_id), args, generator| {
let in_ty = fun_signature.args[0].ty;
let in_val = args[0].1.clone().to_basic_value_enum(ctx, generator, in_ty)?;
let func = match prim {
PrimDef::FunNpAny => builtin_fns::call_numpy_any,
PrimDef::FunNpAll => builtin_fns::call_numpy_all,
_ => unreachable!(),
};
let ret = func(generator, ctx, (in_ty, in_val))?;
Ok(Some(ret))
});
create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback)
}
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
(prim.simple_name().into(), method_ty, prim.id())
}

View File

@ -100,6 +100,8 @@ pub enum PrimDef {
FunNpHypot,
FunNpNextAfter,
FunSome,
FunNpAny,
FunNpAll,
}
/// Associated details of a [`PrimDef`]
@ -251,6 +253,8 @@ impl PrimDef {
PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunSome => fun("Some", None),
PrimDef::FunNpAny => fun("np_any", None),
PrimDef::FunNpAll => fun("np_all", None),
}
}
}

View File

@ -226,6 +226,8 @@ def patch(module):
module.np_full = np.full
module.np_eye = np.eye
module.np_identity = np.identity
module.np_any = np.any
module.np_all = np.all
def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename)

View File

@ -1379,6 +1379,48 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_any():
x1 = np_identity(5)
y1 = np_any(x1)
output_ndarray_float_2(x1)
output_bool(y1)
x2 = np_identity(1)
y2 = np_any(x2)
output_ndarray_float_2(x2)
output_bool(y2)
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
y3 = np_any(x3)
output_ndarray_float_2(x3)
output_bool(y3)
x4 = np_zeros([3, 5])
y4 = np_any(x4)
output_ndarray_float_2(x4)
output_bool(y4)
def test_ndarray_all():
x1 = np_identity(5)
y1 = np_all(x1)
output_ndarray_float_2(x1)
output_bool(y1)
x2 = np_identity(1)
y2 = np_all(x2)
output_ndarray_float_2(x2)
output_bool(y2)
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
y3 = np_all(x3)
output_ndarray_float_2(x3)
output_bool(y3)
x4 = np_zeros([3, 5])
y4 = np_all(x4)
output_ndarray_float_2(x4)
output_bool(y4)
def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
@ -1555,4 +1597,7 @@ def run() -> int32:
test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_any()
test_ndarray_all()
return 0