diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index aefbac9..f913778 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -807,7 +807,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let (a_ty, a) = a; - Ok( match a { + Ok(match a { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { debug_assert!([ ctx.primitives.bool, @@ -819,17 +819,17 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ] .iter() .any(|ty| ctx.unifier.unioned(a_ty, *ty))); - + match fn_name { "np_argmin" | "np_argmax" => llvm_int64.const_zero().into(), "np_max" | "np_min" => a, - _ => unreachable!() + _ => unreachable!(), } } BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); @@ -865,32 +865,42 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ctx, llvm_int64.const_int(1, false), (n_sz, false), - |generator, ctx, _, idx,| { + |generator, ctx, _, idx| { let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); let result = match fn_name { - "np_argmin" | "np_min" => call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)), - "np_argmax" | "np_max" => call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)), - _ => unreachable!() + "np_argmin" | "np_min" => { + call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)) + } + "np_argmax" | "np_max" => { + call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)) + } + _ => unreachable!(), }; - let updated_idx = match (accumulator, result){ - (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { - ctx.builder.build_select( - ctx.builder.build_int_compare(IntPredicate::NE,m, n, "").unwrap(), - idx.into(), + let updated_idx = match (accumulator, result) { + (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx + .builder + .build_select( + ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(), + idx.into(), cur_idx, - "").unwrap() - }, - (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { - ctx.builder.build_select( - ctx.builder.build_float_compare(FloatPredicate::ONE,m, n, "").unwrap(), - idx.into(), + "", + ) + .unwrap(), + (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => ctx + .builder + .build_select( + ctx.builder + .build_float_compare(FloatPredicate::ONE, m, n, "") + .unwrap(), + idx.into(), cur_idx, - "").unwrap() - }, + "", + ) + .unwrap(), _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), }; ctx.builder.build_store(res_idx, updated_idx).unwrap(); @@ -904,11 +914,11 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( match fn_name { "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), - _ => unreachable!() + _ => unreachable!(), } } - _ => unsupported_type(ctx, fn_name, &[a_ty]) + _ => unsupported_type(ctx, fn_name, &[a_ty]), }) } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 7ad22b9..ab639ed 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -510,10 +510,9 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim), - PrimDef::FunNpArgmin - | PrimDef::FunNpArgmax - | PrimDef::FunNpMin - | PrimDef::FunNpMax => self.build_np_max_min_function(prim), + PrimDef::FunNpArgmin | PrimDef::FunNpArgmax | PrimDef::FunNpMin | PrimDef::FunNpMax => { + self.build_np_max_min_function(prim) + } PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => { self.build_np_minimum_maximum_function(prim) @@ -1561,12 +1560,15 @@ impl<'a> BuiltinBuilder<'a> { /// Build the functions `np_max()`, `np_min()`, `np_argmax()` and `np_argmin()` /// Calls `call_numpy_max_min` with the function name fn build_np_max_min_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax]); + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpArgmin, PrimDef::FunNpArgmax, PrimDef::FunNpMin, PrimDef::FunNpMax], + ); let (var_map, ret_ty) = match prim { PrimDef::FunNpArgmax | PrimDef::FunNpArgmin => { (self.num_or_ndarray_var_map.clone(), self.primitives.int64) - }, + } PrimDef::FunNpMax | PrimDef::FunNpMin => { let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None); let var_map = self @@ -1576,8 +1578,8 @@ impl<'a> BuiltinBuilder<'a> { .chain(once((ret_ty.id, ret_ty.ty))) .collect::<IndexMap<_, _>>(); (var_map, ret_ty.ty) - }, - _ => unreachable!() + } + _ => unreachable!(), }; create_fn_by_codegen( @@ -1589,7 +1591,7 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let a_ty = fun.0.args[0].ty; let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; - + Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), &prim.name())?)) }), )