forked from M-Labs/nac3
core/builtins: Add np_minimum/np_maximum
This commit is contained in:
parent
73e81259f3
commit
520e1adc56
@ -792,6 +792,91 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_minimum` builtin function.
|
||||||
|
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
x2: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_minimum";
|
||||||
|
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let (x2_ty, x2) = x2;
|
||||||
|
|
||||||
|
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) {
|
||||||
|
Some(x1_ty)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(match (x1, x2) {
|
||||||
|
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
|
||||||
|
debug_assert!([
|
||||||
|
ctx.primitives.bool,
|
||||||
|
ctx.primitives.int32,
|
||||||
|
ctx.primitives.uint32,
|
||||||
|
ctx.primitives.int64,
|
||||||
|
ctx.primitives.uint64,
|
||||||
|
ctx.primitives.float,
|
||||||
|
].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
|
||||||
|
|
||||||
|
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||||
|
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
|
||||||
|
|
||||||
|
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
(x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => {
|
||||||
|
let is_ndarray1 = x1_ty.obj_id(&ctx.unifier)
|
||||||
|
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
let is_ndarray2 = x2_ty.obj_id(&ctx.unifier)
|
||||||
|
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
|
||||||
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
|
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||||
|
|
||||||
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
|
ndarray_dtype1
|
||||||
|
} else if is_ndarray1 {
|
||||||
|
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||||
|
} else if is_ndarray2 {
|
||||||
|
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||||
|
} else { unreachable!() };
|
||||||
|
|
||||||
|
let x1_scalar_ty = if is_ndarray1 {
|
||||||
|
dtype
|
||||||
|
} else {
|
||||||
|
x1_ty
|
||||||
|
};
|
||||||
|
let x2_scalar_ty = if is_ndarray2 {
|
||||||
|
dtype
|
||||||
|
} else {
|
||||||
|
x2_ty
|
||||||
|
};
|
||||||
|
|
||||||
|
numpy::ndarray_elementwise_binop_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
dtype,
|
||||||
|
None,
|
||||||
|
(x1, !is_ndarray1),
|
||||||
|
(x2, !is_ndarray2),
|
||||||
|
|generator, ctx, (lhs, rhs)| {
|
||||||
|
call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
|
},
|
||||||
|
)?.as_ptr_value().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Invokes the `max` builtin function.
|
/// Invokes the `max` builtin function.
|
||||||
pub fn call_max<'ctx>(
|
pub fn call_max<'ctx>(
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
@ -925,6 +1010,91 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the `np_maximum` builtin function.
|
||||||
|
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
x2: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_maximum";
|
||||||
|
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let (x2_ty, x2) = x2;
|
||||||
|
|
||||||
|
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) {
|
||||||
|
Some(x1_ty)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(match (x1, x2) {
|
||||||
|
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
|
||||||
|
debug_assert!([
|
||||||
|
ctx.primitives.bool,
|
||||||
|
ctx.primitives.int32,
|
||||||
|
ctx.primitives.uint32,
|
||||||
|
ctx.primitives.int64,
|
||||||
|
ctx.primitives.uint64,
|
||||||
|
ctx.primitives.float,
|
||||||
|
].iter().any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
|
||||||
|
|
||||||
|
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||||
|
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
|
||||||
|
|
||||||
|
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
(x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) => {
|
||||||
|
let is_ndarray1 = x1_ty.obj_id(&ctx.unifier)
|
||||||
|
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
let is_ndarray2 = x2_ty.obj_id(&ctx.unifier)
|
||||||
|
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
|
||||||
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
|
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||||
|
|
||||||
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
|
ndarray_dtype1
|
||||||
|
} else if is_ndarray1 {
|
||||||
|
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||||
|
} else if is_ndarray2 {
|
||||||
|
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||||
|
} else { unreachable!() };
|
||||||
|
|
||||||
|
let x1_scalar_ty = if is_ndarray1 {
|
||||||
|
dtype
|
||||||
|
} else {
|
||||||
|
x1_ty
|
||||||
|
};
|
||||||
|
let x2_scalar_ty = if is_ndarray2 {
|
||||||
|
dtype
|
||||||
|
} else {
|
||||||
|
x2_ty
|
||||||
|
};
|
||||||
|
|
||||||
|
numpy::ndarray_elementwise_binop_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
dtype,
|
||||||
|
None,
|
||||||
|
(x1, !is_ndarray1),
|
||||||
|
(x2, !is_ndarray2),
|
||||||
|
|generator, ctx, (lhs, rhs)| {
|
||||||
|
call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
|
||||||
|
},
|
||||||
|
)?.as_ptr_value().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Invokes the `abs` builtin function.
|
/// Invokes the `abs` builtin function.
|
||||||
pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -1405,6 +1405,43 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
|
||||||
|
let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
|
||||||
|
let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")];
|
||||||
|
let ret_ty = unifier.get_fresh_var(None, None);
|
||||||
|
|
||||||
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
|
name: "np_minimum".into(),
|
||||||
|
simple_name: "np_minimum".into(),
|
||||||
|
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: param_ty.iter().map(|p| FuncArg {
|
||||||
|
name: p.1.into(),
|
||||||
|
ty: p.0,
|
||||||
|
default_value: None,
|
||||||
|
}).collect(),
|
||||||
|
ret: ret_ty.0,
|
||||||
|
vars: [
|
||||||
|
(x1_ty.1, x1_ty.0),
|
||||||
|
(x2_ty.1, x2_ty.0),
|
||||||
|
(ret_ty.1, ret_ty.0),
|
||||||
|
].into_iter().collect(),
|
||||||
|
})),
|
||||||
|
var_id: vec![x1_ty.1, x2_ty.1],
|
||||||
|
instance_to_symbol: HashMap::default(),
|
||||||
|
instance_to_stmt: HashMap::default(),
|
||||||
|
resolver: None,
|
||||||
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
|
Ok(Some(builtin_fns::call_numpy_minimum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||||
|
})))),
|
||||||
|
loc: None,
|
||||||
|
}))
|
||||||
|
},
|
||||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
name: "max".into(),
|
name: "max".into(),
|
||||||
simple_name: "max".into(),
|
simple_name: "max".into(),
|
||||||
@ -1454,6 +1491,43 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
|
||||||
|
let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0);
|
||||||
|
let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")];
|
||||||
|
let ret_ty = unifier.get_fresh_var(None, None);
|
||||||
|
|
||||||
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
|
name: "np_maximum".into(),
|
||||||
|
simple_name: "np_maximum".into(),
|
||||||
|
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: param_ty.iter().map(|p| FuncArg {
|
||||||
|
name: p.1.into(),
|
||||||
|
ty: p.0,
|
||||||
|
default_value: None,
|
||||||
|
}).collect(),
|
||||||
|
ret: ret_ty.0,
|
||||||
|
vars: [
|
||||||
|
(x1_ty.1, x1_ty.0),
|
||||||
|
(x2_ty.1, x2_ty.0),
|
||||||
|
(ret_ty.1, ret_ty.0),
|
||||||
|
].into_iter().collect(),
|
||||||
|
})),
|
||||||
|
var_id: vec![x1_ty.1, x2_ty.1],
|
||||||
|
instance_to_symbol: HashMap::default(),
|
||||||
|
instance_to_stmt: HashMap::default(),
|
||||||
|
resolver: None,
|
||||||
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
|
Ok(Some(builtin_fns::call_numpy_maximum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||||
|
})))),
|
||||||
|
loc: None,
|
||||||
|
}))
|
||||||
|
},
|
||||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
name: "abs".into(),
|
name: "abs".into(),
|
||||||
simple_name: "abs".into(),
|
simple_name: "abs".into(),
|
||||||
|
@ -5,7 +5,7 @@ expression: res_vec
|
|||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [224]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [238]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar213]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar213\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar227]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar227\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [226]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [240]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [231]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [245]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar212, typevar213]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar212\", \"typevar213\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar226, typevar227]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar226\", \"typevar227\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [232]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [246]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [240]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [254]\n}\n",
|
||||||
]
|
]
|
||||||
|
@ -951,6 +951,8 @@ impl<'a> Inferencer<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if [
|
if [
|
||||||
|
"np_minimum",
|
||||||
|
"np_maximum",
|
||||||
"np_arctan2",
|
"np_arctan2",
|
||||||
"np_copysign",
|
"np_copysign",
|
||||||
"np_fmax",
|
"np_fmax",
|
||||||
@ -959,8 +961,6 @@ impl<'a> Inferencer<'a> {
|
|||||||
"np_hypot",
|
"np_hypot",
|
||||||
"np_nextafter",
|
"np_nextafter",
|
||||||
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 {
|
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 {
|
||||||
let target_ty = self.primitives.float;
|
|
||||||
|
|
||||||
let arg0 = self.fold_expr(args.remove(0))?;
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
let arg0_ty = arg0.custom.unwrap();
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
let arg1 = self.fold_expr(args.remove(0))?;
|
let arg1 = self.fold_expr(args.remove(0))?;
|
||||||
@ -977,6 +977,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
} else {
|
} else {
|
||||||
arg1_ty
|
arg1_ty
|
||||||
};
|
};
|
||||||
|
|
||||||
let expected_arg1_dtype = if id == &"np_ldexp".into() {
|
let expected_arg1_dtype = if id == &"np_ldexp".into() {
|
||||||
self.primitives.int32
|
self.primitives.int32
|
||||||
} else {
|
} else {
|
||||||
@ -993,6 +994,12 @@ impl<'a> Inferencer<'a> {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let target_ty = if id == &"np_minimum".into() || id == &"np_maximum".into() {
|
||||||
|
arg0_dtype
|
||||||
|
} else {
|
||||||
|
self.primitives.float
|
||||||
|
};
|
||||||
|
|
||||||
let ret = if [
|
let ret = if [
|
||||||
&arg0_ty,
|
&arg0_ty,
|
||||||
&arg1_ty,
|
&arg1_ty,
|
||||||
|
@ -175,7 +175,9 @@ def patch(module):
|
|||||||
module.np_isnan = np.isnan
|
module.np_isnan = np.isnan
|
||||||
module.np_isinf = np.isinf
|
module.np_isinf = np.isinf
|
||||||
module.np_min = np.min
|
module.np_min = np.min
|
||||||
|
module.np_minimum = np.minimum
|
||||||
module.np_max = np.max
|
module.np_max = np.max
|
||||||
|
module.np_maximum = np.maximum
|
||||||
module.np_sin = np.sin
|
module.np_sin = np.sin
|
||||||
module.np_cos = np.cos
|
module.np_cos = np.cos
|
||||||
module.np_exp = np.exp
|
module.np_exp = np.exp
|
||||||
|
@ -766,6 +766,42 @@ def test_ndarray_min():
|
|||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
output_float64(y)
|
output_float64(y)
|
||||||
|
|
||||||
|
def test_ndarray_minimum():
|
||||||
|
x = np_identity(2)
|
||||||
|
min_x_zeros = np_minimum(x, np_zeros([2]))
|
||||||
|
min_x_ones = np_minimum(x, np_zeros([2]))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(min_x_zeros)
|
||||||
|
output_ndarray_float_2(min_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_minimum_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
min_x_zeros = np_minimum(x, np_zeros([2]))
|
||||||
|
min_x_ones = np_minimum(x, np_zeros([2]))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(min_x_zeros)
|
||||||
|
output_ndarray_float_2(min_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_minimum_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
min_x_zeros = np_minimum(0.0, x)
|
||||||
|
min_x_ones = np_minimum(1.0, x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(min_x_zeros)
|
||||||
|
output_ndarray_float_2(min_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_minimum_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
min_x_zeros = np_minimum(x, 0.0)
|
||||||
|
min_x_ones = np_minimum(x, 1.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(min_x_zeros)
|
||||||
|
output_ndarray_float_2(min_x_ones)
|
||||||
|
|
||||||
def test_ndarray_max():
|
def test_ndarray_max():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
y = np_max(x)
|
y = np_max(x)
|
||||||
@ -773,6 +809,42 @@ def test_ndarray_max():
|
|||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
output_float64(y)
|
output_float64(y)
|
||||||
|
|
||||||
|
def test_ndarray_maximum():
|
||||||
|
x = np_identity(2)
|
||||||
|
max_x_zeros = np_maximum(x, np_zeros([2]))
|
||||||
|
max_x_ones = np_maximum(x, np_zeros([2]))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(max_x_zeros)
|
||||||
|
output_ndarray_float_2(max_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_maximum_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
max_x_zeros = np_maximum(x, np_zeros([2]))
|
||||||
|
max_x_ones = np_maximum(x, np_zeros([2]))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(max_x_zeros)
|
||||||
|
output_ndarray_float_2(max_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_maximum_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
max_x_zeros = np_maximum(0.0, x)
|
||||||
|
max_x_ones = np_maximum(1.0, x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(max_x_zeros)
|
||||||
|
output_ndarray_float_2(max_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_maximum_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
max_x_zeros = np_maximum(x, 0.0)
|
||||||
|
max_x_ones = np_maximum(x, 1.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(max_x_zeros)
|
||||||
|
output_ndarray_float_2(max_x_ones)
|
||||||
|
|
||||||
def test_ndarray_abs():
|
def test_ndarray_abs():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
y = abs(x)
|
y = abs(x)
|
||||||
@ -1378,7 +1450,15 @@ def run() -> int32:
|
|||||||
test_ndarray_round()
|
test_ndarray_round()
|
||||||
test_ndarray_floor()
|
test_ndarray_floor()
|
||||||
test_ndarray_min()
|
test_ndarray_min()
|
||||||
|
test_ndarray_minimum()
|
||||||
|
test_ndarray_minimum_broadcast()
|
||||||
|
test_ndarray_minimum_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_minimum_broadcast_rhs_scalar()
|
||||||
test_ndarray_max()
|
test_ndarray_max()
|
||||||
|
test_ndarray_maximum()
|
||||||
|
test_ndarray_maximum_broadcast()
|
||||||
|
test_ndarray_maximum_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_maximum_broadcast_rhs_scalar()
|
||||||
test_ndarray_abs()
|
test_ndarray_abs()
|
||||||
test_ndarray_isnan()
|
test_ndarray_isnan()
|
||||||
test_ndarray_isinf()
|
test_ndarray_isinf()
|
||||||
|
Loading…
Reference in New Issue
Block a user