1
0
forked from M-Labs/nac3

core/builtins: Add np_minimum/np_maximum

This commit is contained in:
David Mak 2024-05-08 18:29:11 +08:00
parent 73e81259f3
commit 520e1adc56
10 changed files with 342 additions and 9 deletions

View File

@ -792,6 +792,91 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
}) })
} }
/// Invokes the `np_minimum` builtin function.
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_minimum";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) {
Some(x1_ty)
} else {
None
};
Ok(match (x1, x2) {
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
(x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => {
let is_ndarray1 = x1_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_ndarray2 = x2_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { unreachable!() };
let x1_scalar_ty = if is_ndarray1 {
dtype
} else {
x1_ty
};
let x2_scalar_ty = if is_ndarray2 {
dtype
} else {
x2_ty
};
numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
dtype,
None,
(x1, !is_ndarray1),
(x2, !is_ndarray2),
|generator, ctx, (lhs, rhs)| {
call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
})
}
/// Invokes the `max` builtin function. /// Invokes the `max` builtin function.
pub fn call_max<'ctx>( pub fn call_max<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -925,6 +1010,91 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
}) })
} }
/// Invokes the `np_maximum` builtin function.
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_maximum";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) {
Some(x1_ty)
} else {
None
};
Ok(match (x1, x2) {
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
}
(x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => {
let is_ndarray1 = x1_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_ndarray2 = x2_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { unreachable!() };
let x1_scalar_ty = if is_ndarray1 {
dtype
} else {
x1_ty
};
let x2_scalar_ty = if is_ndarray2 {
dtype
} else {
x2_ty
};
numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
dtype,
None,
(x1, !is_ndarray1),
(x2, !is_ndarray2),
|generator, ctx, (lhs, rhs)| {
call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
})
}
/// Invokes the `abs` builtin function. /// Invokes the `abs` builtin function.
pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,

View File

@ -1405,6 +1405,43 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
}), }),
) )
}, },
{
let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")];
let ret_ty = unifier.get_fresh_var(None, None);
Arc::new(RwLock::new(TopLevelDef::Function {
name: "np_minimum".into(),
simple_name: "np_minimum".into(),
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty.iter().map(|p| FuncArg {
name: p.1.into(),
ty: p.0,
default_value: None,
}).collect(),
ret: ret_ty.0,
vars: [
(x1_ty.1, x1_ty.0),
(x2_ty.1, x2_ty.0),
(ret_ty.1, ret_ty.0),
].into_iter().collect(),
})),
var_id: vec![x1_ty.1, x2_ty.1],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x2_ty = fun.0.args[1].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_numpy_minimum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
})))),
loc: None,
}))
},
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
name: "max".into(), name: "max".into(),
simple_name: "max".into(), simple_name: "max".into(),
@ -1454,6 +1491,43 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
}), }),
) )
}, },
{
let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")];
let ret_ty = unifier.get_fresh_var(None, None);
Arc::new(RwLock::new(TopLevelDef::Function {
name: "np_maximum".into(),
simple_name: "np_maximum".into(),
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty.iter().map(|p| FuncArg {
name: p.1.into(),
ty: p.0,
default_value: None,
}).collect(),
ret: ret_ty.0,
vars: [
(x1_ty.1, x1_ty.0),
(x2_ty.1, x2_ty.0),
(ret_ty.1, ret_ty.0),
].into_iter().collect(),
})),
var_id: vec![x1_ty.1, x2_ty.1],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x2_ty = fun.0.args[1].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_numpy_maximum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
})))),
loc: None,
}))
},
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
name: "abs".into(), name: "abs".into(),
simple_name: "abs".into(), simple_name: "abs".into(),

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [224]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [238]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar213]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar213\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar227]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar227\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [226]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [240]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [231]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [245]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar212, typevar213]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar212\", \"typevar213\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar226, typevar227]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar226\", \"typevar227\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [232]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [246]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [240]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [254]\n}\n",
] ]

View File

@ -951,6 +951,8 @@ impl<'a> Inferencer<'a> {
} }
if [ if [
"np_minimum",
"np_maximum",
"np_arctan2", "np_arctan2",
"np_copysign", "np_copysign",
"np_fmax", "np_fmax",
@ -959,8 +961,6 @@ impl<'a> Inferencer<'a> {
"np_hypot", "np_hypot",
"np_nextafter", "np_nextafter",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 { ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 {
let target_ty = self.primitives.float;
let arg0 = self.fold_expr(args.remove(0))?; let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap(); let arg0_ty = arg0.custom.unwrap();
let arg1 = self.fold_expr(args.remove(0))?; let arg1 = self.fold_expr(args.remove(0))?;
@ -977,6 +977,7 @@ impl<'a> Inferencer<'a> {
} else { } else {
arg1_ty arg1_ty
}; };
let expected_arg1_dtype = if id == &"np_ldexp".into() { let expected_arg1_dtype = if id == &"np_ldexp".into() {
self.primitives.int32 self.primitives.int32
} else { } else {
@ -993,6 +994,12 @@ impl<'a> Inferencer<'a> {
) )
} }
let target_ty = if id == &"np_minimum".into() || id == &"np_maximum".into() {
arg0_dtype
} else {
self.primitives.float
};
let ret = if [ let ret = if [
&arg0_ty, &arg0_ty,
&arg1_ty, &arg1_ty,

View File

@ -175,7 +175,9 @@ def patch(module):
module.np_isnan = np.isnan module.np_isnan = np.isnan
module.np_isinf = np.isinf module.np_isinf = np.isinf
module.np_min = np.min module.np_min = np.min
module.np_minimum = np.minimum
module.np_max = np.max module.np_max = np.max
module.np_maximum = np.maximum
module.np_sin = np.sin module.np_sin = np.sin
module.np_cos = np.cos module.np_cos = np.cos
module.np_exp = np.exp module.np_exp = np.exp

View File

@ -766,6 +766,42 @@ def test_ndarray_min():
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_float64(y) output_float64(y)
def test_ndarray_minimum():
x = np_identity(2)
min_x_zeros = np_minimum(x, np_zeros([2]))
min_x_ones = np_minimum(x, np_zeros([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones)
def test_ndarray_minimum_broadcast():
x = np_identity(2)
min_x_zeros = np_minimum(x, np_zeros([2]))
min_x_ones = np_minimum(x, np_zeros([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones)
def test_ndarray_minimum_broadcast_lhs_scalar():
x = np_identity(2)
min_x_zeros = np_minimum(0.0, x)
min_x_ones = np_minimum(1.0, x)
output_ndarray_float_2(x)
output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones)
def test_ndarray_minimum_broadcast_rhs_scalar():
x = np_identity(2)
min_x_zeros = np_minimum(x, 0.0)
min_x_ones = np_minimum(x, 1.0)
output_ndarray_float_2(x)
output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones)
def test_ndarray_max(): def test_ndarray_max():
x = np_identity(2) x = np_identity(2)
y = np_max(x) y = np_max(x)
@ -773,6 +809,42 @@ def test_ndarray_max():
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_float64(y) output_float64(y)
def test_ndarray_maximum():
x = np_identity(2)
max_x_zeros = np_maximum(x, np_zeros([2]))
max_x_ones = np_maximum(x, np_zeros([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones)
def test_ndarray_maximum_broadcast():
x = np_identity(2)
max_x_zeros = np_maximum(x, np_zeros([2]))
max_x_ones = np_maximum(x, np_zeros([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones)
def test_ndarray_maximum_broadcast_lhs_scalar():
x = np_identity(2)
max_x_zeros = np_maximum(0.0, x)
max_x_ones = np_maximum(1.0, x)
output_ndarray_float_2(x)
output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones)
def test_ndarray_maximum_broadcast_rhs_scalar():
x = np_identity(2)
max_x_zeros = np_maximum(x, 0.0)
max_x_ones = np_maximum(x, 1.0)
output_ndarray_float_2(x)
output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones)
def test_ndarray_abs(): def test_ndarray_abs():
x = np_identity(2) x = np_identity(2)
y = abs(x) y = abs(x)
@ -1378,7 +1450,15 @@ def run() -> int32:
test_ndarray_round() test_ndarray_round()
test_ndarray_floor() test_ndarray_floor()
test_ndarray_min() test_ndarray_min()
test_ndarray_minimum()
test_ndarray_minimum_broadcast()
test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_max() test_ndarray_max()
test_ndarray_maximum()
test_ndarray_maximum_broadcast()
test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_abs() test_ndarray_abs()
test_ndarray_isnan() test_ndarray_isnan()
test_ndarray_isinf() test_ndarray_isinf()