From 520e1adc56199cefc1c571fbe656637930157458 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 8 May 2024 18:29:11 +0800 Subject: [PATCH] core/builtins: Add np_minimum/np_maximum --- nac3core/src/codegen/builtin_fns.rs | 170 ++++++++++++++++++ nac3core/src/toplevel/builtins.rs | 74 ++++++++ ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/type_inferencer/mod.rs | 11 +- nac3standalone/demo/interpret_demo.py | 2 + nac3standalone/demo/src/ndarray.py | 80 +++++++++ 10 files changed, 342 insertions(+), 9 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index e0067bb5a..1fbfd7121 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -792,6 +792,91 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( }) } +/// Invokes the `np_minimum` builtin function. +pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_minimum"; + + let (x1_ty, x1) = x1; + let (x2_ty, x2) = x2; + + let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { + Some(x1_ty) + } else { + None + }; + + Ok(match (x1, x2) { + (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { + debug_assert!([ + ctx.primitives.bool, + ctx.primitives.int32, + ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ctx.primitives.float, + ].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); + + call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) + } + + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); + + call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) + } + + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { + let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + let dtype = if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + + debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + ndarray_dtype1 + } else if is_ndarray1 { + unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + } else if is_ndarray2 { + unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + } else { unreachable!() }; + + let x1_scalar_ty = if is_ndarray1 { + dtype + } else { + x1_ty + }; + let x2_scalar_ty = if is_ndarray2 { + dtype + } else { + x2_ty + }; + + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + dtype, + None, + (x1, !is_ndarray1), + (x2, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + }, + )?.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }) +} + /// Invokes the `max` builtin function. pub fn call_max<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, @@ -925,6 +1010,91 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( }) } +/// Invokes the `np_maximum` builtin function. +pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_maximum"; + + let (x1_ty, x1) = x1; + let (x2_ty, x2) = x2; + + let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { + Some(x1_ty) + } else { + None + }; + + Ok(match (x1, x2) { + (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { + debug_assert!([ + ctx.primitives.bool, + ctx.primitives.int32, + ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ctx.primitives.float, + ].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty))); + + call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) + } + + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); + + call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) + } + + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => { + let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + let dtype = if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + + debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + ndarray_dtype1 + } else if is_ndarray1 { + unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + } else if is_ndarray2 { + unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + } else { unreachable!() }; + + let x1_scalar_ty = if is_ndarray1 { + dtype + } else { + x1_ty + }; + let x2_scalar_ty = if is_ndarray2 { + dtype + } else { + x2_ty + }; + + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + dtype, + None, + (x1, !is_ndarray1), + (x2, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + }, + )?.as_ptr_value().into() + } + + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + }) +} + /// Invokes the `abs` builtin function. pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 7187ae8f0..56026170b 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1405,6 +1405,43 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built }), ) }, + { + let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); + let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = unifier.get_fresh_var(None, None); + + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_minimum".into(), + simple_name: "np_minimum".into(), + signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty.iter().map(|p| FuncArg { + name: p.1.into(), + ty: p.0, + default_value: None, + }).collect(), + ret: ret_ty.0, + vars: [ + (x1_ty.1, x1_ty.0), + (x2_ty.1, x2_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + })), + var_id: vec![x1_ty.1, x2_ty.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x2_ty = fun.0.args[1].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_numpy_minimum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + })))), + loc: None, + })) + }, Arc::new(RwLock::new(TopLevelDef::Function { name: "max".into(), simple_name: "max".into(), @@ -1454,6 +1491,43 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built }), ) }, + { + let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); + let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = unifier.get_fresh_var(None, None); + + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_maximum".into(), + simple_name: "np_maximum".into(), + signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty.iter().map(|p| FuncArg { + name: p.1.into(), + ty: p.0, + default_value: None, + }).collect(), + ret: ret_ty.0, + vars: [ + (x1_ty.1, x1_ty.0), + (x2_ty.1, x2_ty.0), + (ret_ty.1, ret_ty.0), + ].into_iter().collect(), + })), + var_id: vec![x1_ty.1, x2_ty.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x2_ty = fun.0.args[1].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_numpy_maximum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + })))), + loc: None, + })) + }, Arc::new(RwLock::new(TopLevelDef::Function { name: "abs".into(), simple_name: "abs".into(), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 65ceb2497..7beeae4a1 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [224]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [238]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 2ee8e8fc5..0c8efccd7 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar213]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar213\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar227]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar227\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 40b968863..f3e5fdeaa 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [226]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [231]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [240]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [245]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 7204c0d45..279eab9d2 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar212, typevar213]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar212\", \"typevar213\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar226, typevar227]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar226\", \"typevar227\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 1a526c5ea..24ae22382 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [232]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [246]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [240]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [254]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 61c56c6b5..c26366d81 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -951,6 +951,8 @@ impl<'a> Inferencer<'a> { } if [ + "np_minimum", + "np_maximum", "np_arctan2", "np_copysign", "np_fmax", @@ -959,8 +961,6 @@ impl<'a> Inferencer<'a> { "np_hypot", "np_nextafter", ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 { - let target_ty = self.primitives.float; - let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); let arg1 = self.fold_expr(args.remove(0))?; @@ -977,6 +977,7 @@ impl<'a> Inferencer<'a> { } else { arg1_ty }; + let expected_arg1_dtype = if id == &"np_ldexp".into() { self.primitives.int32 } else { @@ -993,6 +994,12 @@ impl<'a> Inferencer<'a> { ) } + let target_ty = if id == &"np_minimum".into() || id == &"np_maximum".into() { + arg0_dtype + } else { + self.primitives.float + }; + let ret = if [ &arg0_ty, &arg1_ty, diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index b5fb153ae..a26418427 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -175,7 +175,9 @@ def patch(module): module.np_isnan = np.isnan module.np_isinf = np.isinf module.np_min = np.min + module.np_minimum = np.minimum module.np_max = np.max + module.np_maximum = np.maximum module.np_sin = np.sin module.np_cos = np.cos module.np_exp = np.exp diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 66f6fe937..919824296 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -766,6 +766,42 @@ def test_ndarray_min(): output_ndarray_float_2(x) output_float64(y) +def test_ndarray_minimum(): + x = np_identity(2) + min_x_zeros = np_minimum(x, np_zeros([2])) + min_x_ones = np_minimum(x, np_zeros([2])) + + output_ndarray_float_2(x) + output_ndarray_float_2(min_x_zeros) + output_ndarray_float_2(min_x_ones) + +def test_ndarray_minimum_broadcast(): + x = np_identity(2) + min_x_zeros = np_minimum(x, np_zeros([2])) + min_x_ones = np_minimum(x, np_zeros([2])) + + output_ndarray_float_2(x) + output_ndarray_float_2(min_x_zeros) + output_ndarray_float_2(min_x_ones) + +def test_ndarray_minimum_broadcast_lhs_scalar(): + x = np_identity(2) + min_x_zeros = np_minimum(0.0, x) + min_x_ones = np_minimum(1.0, x) + + output_ndarray_float_2(x) + output_ndarray_float_2(min_x_zeros) + output_ndarray_float_2(min_x_ones) + +def test_ndarray_minimum_broadcast_rhs_scalar(): + x = np_identity(2) + min_x_zeros = np_minimum(x, 0.0) + min_x_ones = np_minimum(x, 1.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(min_x_zeros) + output_ndarray_float_2(min_x_ones) + def test_ndarray_max(): x = np_identity(2) y = np_max(x) @@ -773,6 +809,42 @@ def test_ndarray_max(): output_ndarray_float_2(x) output_float64(y) +def test_ndarray_maximum(): + x = np_identity(2) + max_x_zeros = np_maximum(x, np_zeros([2])) + max_x_ones = np_maximum(x, np_zeros([2])) + + output_ndarray_float_2(x) + output_ndarray_float_2(max_x_zeros) + output_ndarray_float_2(max_x_ones) + +def test_ndarray_maximum_broadcast(): + x = np_identity(2) + max_x_zeros = np_maximum(x, np_zeros([2])) + max_x_ones = np_maximum(x, np_zeros([2])) + + output_ndarray_float_2(x) + output_ndarray_float_2(max_x_zeros) + output_ndarray_float_2(max_x_ones) + +def test_ndarray_maximum_broadcast_lhs_scalar(): + x = np_identity(2) + max_x_zeros = np_maximum(0.0, x) + max_x_ones = np_maximum(1.0, x) + + output_ndarray_float_2(x) + output_ndarray_float_2(max_x_zeros) + output_ndarray_float_2(max_x_ones) + +def test_ndarray_maximum_broadcast_rhs_scalar(): + x = np_identity(2) + max_x_zeros = np_maximum(x, 0.0) + max_x_ones = np_maximum(x, 1.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(max_x_zeros) + output_ndarray_float_2(max_x_ones) + def test_ndarray_abs(): x = np_identity(2) y = abs(x) @@ -1378,7 +1450,15 @@ def run() -> int32: test_ndarray_round() test_ndarray_floor() test_ndarray_min() + test_ndarray_minimum() + test_ndarray_minimum_broadcast() + test_ndarray_minimum_broadcast_lhs_scalar() + test_ndarray_minimum_broadcast_rhs_scalar() test_ndarray_max() + test_ndarray_maximum() + test_ndarray_maximum_broadcast() + test_ndarray_maximum_broadcast_lhs_scalar() + test_ndarray_maximum_broadcast_rhs_scalar() test_ndarray_abs() test_ndarray_isnan() test_ndarray_isinf()