forked from M-Labs/nac3
cargo fmt
This commit is contained in:
parent
c80378063a
commit
2dbc1ec659
@ -807,7 +807,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (a_ty, a) = a;
|
||||
Ok( match a {
|
||||
Ok(match a {
|
||||
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
@ -823,7 +823,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
match fn_name {
|
||||
"np_argmin" | "np_argmax" => llvm_int64.const_zero().into(),
|
||||
"np_max" | "np_min" => a,
|
||||
_ => unreachable!()
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
BasicValueEnum::PointerValue(n)
|
||||
@ -865,32 +865,42 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx,
|
||||
llvm_int64.const_int(1, false),
|
||||
(n_sz, false),
|
||||
|generator, ctx, _, idx,| {
|
||||
|generator, ctx, _, idx| {
|
||||
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
|
||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||
let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();
|
||||
|
||||
let result = match fn_name {
|
||||
"np_argmin" | "np_min" => call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)),
|
||||
"np_argmax" | "np_max" => call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)),
|
||||
_ => unreachable!()
|
||||
"np_argmin" | "np_min" => {
|
||||
call_min(ctx, (elem_ty, accumulator), (elem_ty, elem))
|
||||
}
|
||||
"np_argmax" | "np_max" => {
|
||||
call_max(ctx, (elem_ty, accumulator), (elem_ty, elem))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let updated_idx = match (accumulator, result){
|
||||
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
|
||||
ctx.builder.build_select(
|
||||
ctx.builder.build_int_compare(IntPredicate::NE,m, n, "").unwrap(),
|
||||
let updated_idx = match (accumulator, result) {
|
||||
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx
|
||||
.builder
|
||||
.build_select(
|
||||
ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(),
|
||||
idx.into(),
|
||||
cur_idx,
|
||||
"").unwrap()
|
||||
},
|
||||
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
|
||||
ctx.builder.build_select(
|
||||
ctx.builder.build_float_compare(FloatPredicate::ONE,m, n, "").unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap(),
|
||||
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => ctx
|
||||
.builder
|
||||
.build_select(
|
||||
ctx.builder
|
||||
.build_float_compare(FloatPredicate::ONE, m, n, "")
|
||||
.unwrap(),
|
||||
idx.into(),
|
||||
cur_idx,
|
||||
"").unwrap()
|
||||
},
|
||||
"",
|
||||
)
|
||||
.unwrap(),
|
||||
_ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]),
|
||||
};
|
||||
ctx.builder.build_store(res_idx, updated_idx).unwrap();
|
||||
@ -904,11 +914,11 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
match fn_name {
|
||||
"np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(),
|
||||
"np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(),
|
||||
_ => unreachable!()
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, fn_name, &[a_ty])
|
||||
_ => unsupported_type(ctx, fn_name, &[a_ty]),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -510,10 +510,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
|
||||
PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim),
|
||||
|
||||
PrimDef::FunNpArgmin
|
||||
| PrimDef::FunNpArgmax
|
||||
| PrimDef::FunNpMin
|
||||
| PrimDef::FunNpMax => self.build_np_max_min_function(prim),
|
||||
PrimDef::FunNpArgmin | PrimDef::FunNpArgmax | PrimDef::FunNpMin | PrimDef::FunNpMax => {
|
||||
self.build_np_max_min_function(prim)
|
||||
}
|
||||
|
||||
PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => {
|
||||
self.build_np_minimum_maximum_function(prim)
|
||||
@ -1561,12 +1560,15 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
/// Build the functions `np_max()`, `np_min()`, `np_argmax()` and `np_argmin()`
|
||||
/// Calls `call_numpy_max_min` with the function name
|
||||
fn build_np_max_min_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax]);
|
||||
debug_assert_prim_is_allowed(
|
||||
prim,
|
||||
&[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax],
|
||||
);
|
||||
|
||||
let (var_map, ret_ty) = match prim {
|
||||
PrimDef::FunNpArgmax | PrimDef::FunNpArgmin => {
|
||||
(self.num_or_ndarray_var_map.clone(), self.primitives.int64)
|
||||
},
|
||||
}
|
||||
PrimDef::FunNpMax | PrimDef::FunNpMin => {
|
||||
let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None);
|
||||
let var_map = self
|
||||
@ -1576,8 +1578,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
.chain(once((ret_ty.id, ret_ty.ty)))
|
||||
.collect::<IndexMap<_, _>>();
|
||||
(var_map, ret_ty.ty)
|
||||
},
|
||||
_ => unreachable!()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
create_fn_by_codegen(
|
||||
|
Loading…
Reference in New Issue
Block a user