forked from M-Labs/nac3
1
0
Fork 0

cargo fmt

This commit is contained in:
lyken 2024-07-12 21:16:38 +08:00
parent c80378063a
commit 2dbc1ec659
2 changed files with 44 additions and 32 deletions

View File

@ -807,7 +807,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a; let (a_ty, a) = a;
Ok( match a { Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([ debug_assert!([
ctx.primitives.bool, ctx.primitives.bool,
@ -819,17 +819,17 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
] ]
.iter() .iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty))); .any(|ty| ctx.unifier.unioned(a_ty, *ty)));
match fn_name { match fn_name {
"np_argmin" | "np_argmax" => llvm_int64.const_zero().into(), "np_argmin" | "np_argmax" => llvm_int64.const_zero().into(),
"np_max" | "np_min" => a, "np_max" | "np_min" => a,
_ => unreachable!() _ => unreachable!(),
} }
} }
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -865,32 +865,42 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
llvm_int64.const_int(1, false), llvm_int64.const_int(1, false),
(n_sz, false), (n_sz, false),
|generator, ctx, _, idx,| { |generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();
let result = match fn_name { let result = match fn_name {
"np_argmin" | "np_min" => call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)), "np_argmin" | "np_min" => {
"np_argmax" | "np_max" => call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)), call_min(ctx, (elem_ty, accumulator), (elem_ty, elem))
_ => unreachable!() }
"np_argmax" | "np_max" => {
call_max(ctx, (elem_ty, accumulator), (elem_ty, elem))
}
_ => unreachable!(),
}; };
let updated_idx = match (accumulator, result){ let updated_idx = match (accumulator, result) {
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx
ctx.builder.build_select( .builder
ctx.builder.build_int_compare(IntPredicate::NE,m, n, "").unwrap(), .build_select(
idx.into(), ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(),
idx.into(),
cur_idx, cur_idx,
"").unwrap() "",
}, )
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { .unwrap(),
ctx.builder.build_select( (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => ctx
ctx.builder.build_float_compare(FloatPredicate::ONE,m, n, "").unwrap(), .builder
idx.into(), .build_select(
ctx.builder
.build_float_compare(FloatPredicate::ONE, m, n, "")
.unwrap(),
idx.into(),
cur_idx, cur_idx,
"").unwrap() "",
}, )
.unwrap(),
_ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]),
}; };
ctx.builder.build_store(res_idx, updated_idx).unwrap(); ctx.builder.build_store(res_idx, updated_idx).unwrap();
@ -904,11 +914,11 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
match fn_name { match fn_name {
"np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(),
"np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(),
_ => unreachable!() _ => unreachable!(),
} }
} }
_ => unsupported_type(ctx, fn_name, &[a_ty]) _ => unsupported_type(ctx, fn_name, &[a_ty]),
}) })
} }

View File

@ -510,10 +510,9 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim), PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim),
PrimDef::FunNpArgmin PrimDef::FunNpArgmin | PrimDef::FunNpArgmax | PrimDef::FunNpMin | PrimDef::FunNpMax => {
| PrimDef::FunNpArgmax self.build_np_max_min_function(prim)
| PrimDef::FunNpMin }
| PrimDef::FunNpMax => self.build_np_max_min_function(prim),
PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => { PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => {
self.build_np_minimum_maximum_function(prim) self.build_np_minimum_maximum_function(prim)
@ -1561,12 +1560,15 @@ impl<'a> BuiltinBuilder<'a> {
/// Build the functions `np_max()`, `np_min()`, `np_argmax()` and `np_argmin()` /// Build the functions `np_max()`, `np_min()`, `np_argmax()` and `np_argmin()`
/// Calls `call_numpy_max_min` with the function name /// Calls `call_numpy_max_min` with the function name
fn build_np_max_min_function(&mut self, prim: PrimDef) -> TopLevelDef { fn build_np_max_min_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax]); debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax],
);
let (var_map, ret_ty) = match prim { let (var_map, ret_ty) = match prim {
PrimDef::FunNpArgmax | PrimDef::FunNpArgmin => { PrimDef::FunNpArgmax | PrimDef::FunNpArgmin => {
(self.num_or_ndarray_var_map.clone(), self.primitives.int64) (self.num_or_ndarray_var_map.clone(), self.primitives.int64)
}, }
PrimDef::FunNpMax | PrimDef::FunNpMin => { PrimDef::FunNpMax | PrimDef::FunNpMin => {
let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None); let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None);
let var_map = self let var_map = self
@ -1576,8 +1578,8 @@ impl<'a> BuiltinBuilder<'a> {
.chain(once((ret_ty.id, ret_ty.ty))) .chain(once((ret_ty.id, ret_ty.ty)))
.collect::<IndexMap<_, _>>(); .collect::<IndexMap<_, _>>();
(var_map, ret_ty.ty) (var_map, ret_ty.ty)
}, }
_ => unreachable!() _ => unreachable!(),
}; };
create_fn_by_codegen( create_fn_by_codegen(
@ -1589,7 +1591,7 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let a_ty = fun.0.args[0].ty; let a_ty = fun.0.args[0].ty;
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), &prim.name())?)) Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), &prim.name())?))
}), }),
) )