Allow numpy functions to accept scalars or ndarrays #400
File diff suppressed because it is too large
Load Diff
|
@ -451,8 +451,6 @@ fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
debug_assert_eq!(lhs_elem.get_type(), rhs_elem.get_type());
|
|
||||||
|
|
||||||
value_fn(generator, ctx, (lhs_elem, rhs_elem))
|
value_fn(generator, ctx, (lhs_elem, rhs_elem))
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -5,7 +5,7 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [127]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [222]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar116]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar116\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar211]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar211\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [129]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [224]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [134]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [229]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar115, typevar116]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar115\", \"typevar116\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar210, typevar211]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar210\", \"typevar211\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [135]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [230]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [143]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [238]\n}\n",
|
||||||
]
|
]
|
||||||
|
|
|
@ -14,17 +14,7 @@ use crate::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use itertools::{Itertools, izip};
|
use itertools::{Itertools, izip};
|
||||||
use nac3parser::ast::{
|
use nac3parser::ast::{self, fold::{self, Fold}, Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef};
|
||||||
self,
|
|
||||||
fold::{self, Fold},
|
|
||||||
Arguments,
|
|
||||||
Comprehension,
|
|
||||||
ExprContext,
|
|
||||||
ExprKind,
|
|
||||||
Located,
|
|
||||||
Location,
|
|
||||||
StrRef
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
@ -860,17 +850,194 @@ impl<'a> Inferencer<'a> {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
// int64 is special because its argument can be a constant larger than int32
|
|
||||||
if id == &"int64".into() && args.len() == 1 {
|
if [
|
||||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
"int32",
|
||||||
&args[0].node
|
"float",
|
||||||
{
|
"bool",
|
||||||
let custom = Some(self.primitives.int64);
|
"np_isnan",
|
||||||
let v: Result<i64, _> = (*val).try_into();
|
"np_isinf",
|
||||||
return if v.is_ok() {
|
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||||
|
let target_ty = if id == &"int32".into() {
|
||||||
|
self.primitives.int32
|
||||||
|
} else if id == &"float".into() {
|
||||||
|
self.primitives.float
|
||||||
|
} else if id == &"bool".into() || id == &"np_isnan".into() || id == &"np_isinf".into() {
|
||||||
|
self.primitives.bool
|
||||||
|
} else { unreachable!() };
|
||||||
|
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
|
|
||||||
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
|
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||||
|
|
||||||
|
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
||||||
|
} else {
|
||||||
|
target_ty
|
||||||
|
};
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg {
|
||||||
|
name: "n".into(),
|
||||||
|
ty: arg0.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||||
|
}),
|
||||||
|
args: vec![arg0],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
if [
|
||||||
|
"np_arctan2",
|
||||||
|
"np_copysign",
|
||||||
|
"np_fmax",
|
||||||
|
"np_fmin",
|
||||||
|
"np_ldexp",
|
||||||
|
"np_hypot",
|
||||||
|
"np_nextafter",
|
||||||
|
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 {
|
||||||
|
let target_ty = self.primitives.float;
|
||||||
|
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
|
let arg1 = self.fold_expr(args.remove(0))?;
|
||||||
|
let arg1_ty = arg1.custom.unwrap();
|
||||||
|
|
||||||
|
let arg0_dtype = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
|
unpack_ndarray_var_tys(self.unifier, arg0_ty).0
|
||||||
|
} else {
|
||||||
|
arg0_ty
|
||||||
|
};
|
||||||
|
|
||||||
|
let arg1_dtype = if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
|
unpack_ndarray_var_tys(self.unifier, arg1_ty).0
|
||||||
|
} else {
|
||||||
|
arg1_ty
|
||||||
|
};
|
||||||
|
let expected_arg1_dtype = if id == &"np_ldexp".into() {
|
||||||
|
self.primitives.int32
|
||||||
|
} else {
|
||||||
|
arg0_dtype
|
||||||
|
};
|
||||||
|
if !self.unifier.unioned(arg1_dtype, expected_arg1_dtype) {
|
||||||
|
return report_error(
|
||||||
|
format!(
|
||||||
|
"Expected {} for second argument of {id}, got {}",
|
||||||
|
self.unifier.stringify(expected_arg1_dtype),
|
||||||
|
self.unifier.stringify(arg1_dtype),
|
||||||
|
).as_str(),
|
||||||
|
arg0.location,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
let ret = if [
|
||||||
|
&arg0_ty,
|
||||||
|
&arg1_ty,
|
||||||
|
].into_iter().any(|arg_ty| arg_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) {
|
||||||
|
// typeof_ndarray_broadcast requires both dtypes to be the same, but ldexp accepts
|
||||||
|
// (float, int32), so convert it to align with the dtype of the first arg
|
||||||
|
let arg1_ty = if id == &"np_ldexp".into() {
|
||||||
|
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
|
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
|
||||||
|
|
||||||
|
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims))
|
||||||
|
} else {
|
||||||
|
target_ty
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
arg1_ty
|
||||||
|
};
|
||||||
|
|
||||||
|
match typeof_ndarray_broadcast(self.unifier, self.primitives, arg0_ty, arg1_ty) {
|
||||||
|
Ok(broadcasted_ty) => broadcasted_ty,
|
||||||
|
Err(err) => return report_error(err.as_str(), location),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
target_ty
|
||||||
|
};
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg {
|
||||||
|
name: "x1".into(),
|
||||||
|
ty: arg0.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
FuncArg {
|
||||||
|
name: "x2".into(),
|
||||||
|
ty: arg1.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||||
|
}),
|
||||||
|
args: vec![arg0, arg1],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// int64, uint32 and uint64 are special because their argument can be a constant outside the
|
||||||
|
// range of int32s
|
||||||
|
if [
|
||||||
|
"int64",
|
||||||
|
"uint32",
|
||||||
|
"uint64",
|
||||||
|
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||||
|
let target_ty = if id == &"int64".into() {
|
||||||
|
self.primitives.int64
|
||||||
|
} else if id == &"uint32".into() {
|
||||||
|
self.primitives.uint32
|
||||||
|
} else if id == &"uint64".into() {
|
||||||
|
self.primitives.uint64
|
||||||
|
} else { unreachable!() };
|
||||||
|
|
||||||
|
// Handle constants first to ensure that their types are not defaulted to int32, which
|
||||||
|
// causes an "Integer out of bound" error
|
||||||
|
if let ExprKind::Constant {
|
||||||
|
value: ast::Constant::Int(val),
|
||||||
|
kind
|
||||||
|
} = &args[0].node {
|
||||||
|
let conv_is_ok = if self.unifier.unioned(target_ty, self.primitives.int64) {
|
||||||
|
i64::try_from(*val).is_ok()
|
||||||
|
} else if self.unifier.unioned(target_ty, self.primitives.uint32) {
|
||||||
|
u32::try_from(*val).is_ok()
|
||||||
|
} else if self.unifier.unioned(target_ty, self.primitives.uint64) {
|
||||||
|
u64::try_from(*val).is_ok()
|
||||||
|
} else { unreachable!() };
|
||||||
|
|
||||||
|
return if conv_is_ok {
|
||||||
Ok(Some(Located {
|
Ok(Some(Located {
|
||||||
location: args[0].location,
|
location: args[0].location,
|
||||||
custom,
|
custom: Some(target_ty),
|
||||||
node: ExprKind::Constant {
|
node: ExprKind::Constant {
|
||||||
value: ast::Constant::Int(*val),
|
value: ast::Constant::Int(*val),
|
||||||
kind: kind.clone(),
|
kind: kind.clone(),
|
||||||
|
@ -880,46 +1047,43 @@ impl<'a> Inferencer<'a> {
|
||||||
report_error("Integer out of bound", args[0].location)
|
report_error("Integer out of bound", args[0].location)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if id == &"uint32".into() && args.len() == 1 {
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
&args[0].node
|
|
||||||
{
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
let custom = Some(self.primitives.uint32);
|
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||||
let v: Result<u32, _> = (*val).try_into();
|
|
||||||
return if v.is_ok() {
|
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
||||||
Ok(Some(Located {
|
} else {
|
||||||
location: args[0].location,
|
target_ty
|
||||||
custom,
|
};
|
||||||
node: ExprKind::Constant {
|
|
||||||
value: ast::Constant::Int(*val),
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
kind: kind.clone(),
|
args: vec![
|
||||||
|
FuncArg {
|
||||||
|
name: "n".into(),
|
||||||
|
ty: arg0.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||||
|
}),
|
||||||
|
args: vec![arg0],
|
||||||
|
keywords: vec![],
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
} else {
|
|
||||||
report_error("Integer out of bound", args[0].location)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if id == &"uint64".into() && args.len() == 1 {
|
|
||||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
|
||||||
&args[0].node
|
|
||||||
{
|
|
||||||
let custom = Some(self.primitives.uint64);
|
|
||||||
let v: Result<u64, _> = (*val).try_into();
|
|
||||||
return if v.is_ok() {
|
|
||||||
Ok(Some(Located {
|
|
||||||
location: args[0].location,
|
|
||||||
custom,
|
|
||||||
node: ExprKind::Constant {
|
|
||||||
value: ast::Constant::Int(*val),
|
|
||||||
kind: kind.clone(),
|
|
||||||
},
|
|
||||||
}))
|
|
||||||
} else {
|
|
||||||
report_error("Integer out of bound", args[0].location)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1-argument ndarray n-dimensional creation functions
|
// 1-argument ndarray n-dimensional creation functions
|
||||||
|
|
|
@ -58,12 +58,39 @@ class _NDArrayDummy(Generic[T, N]):
|
||||||
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
|
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
|
||||||
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
|
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
|
||||||
|
|
||||||
|
def _bool(x):
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.bool_(x)
|
||||||
|
else:
|
||||||
|
return bool(x)
|
||||||
|
|
||||||
|
def _float(x):
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.float_(x)
|
||||||
|
else:
|
||||||
|
return float(x)
|
||||||
|
|
||||||
def round_away_zero(x):
|
def round_away_zero(x):
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.vectorize(round_away_zero)(x)
|
||||||
|
else:
|
||||||
if x >= 0.0:
|
if x >= 0.0:
|
||||||
return math.floor(x + 0.5)
|
return math.floor(x + 0.5)
|
||||||
else:
|
else:
|
||||||
return math.ceil(x - 0.5)
|
return math.ceil(x - 0.5)
|
||||||
|
|
||||||
|
def _floor(x):
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.vectorize(_floor)(x)
|
||||||
|
else:
|
||||||
|
return math.floor(x)
|
||||||
|
|
||||||
|
def _ceil(x):
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.vectorize(_ceil)(x)
|
||||||
|
else:
|
||||||
|
return math.ceil(x)
|
||||||
|
|
||||||
def patch(module):
|
def patch(module):
|
||||||
def dbl_nan():
|
def dbl_nan():
|
||||||
return np.nan
|
return np.nan
|
||||||
|
@ -112,6 +139,8 @@ def patch(module):
|
||||||
module.int64 = int64
|
module.int64 = int64
|
||||||
module.uint32 = uint32
|
module.uint32 = uint32
|
||||||
module.uint64 = uint64
|
module.uint64 = uint64
|
||||||
|
module.bool = _bool
|
||||||
|
module.float = _float
|
||||||
module.TypeVar = TypeVar
|
module.TypeVar = TypeVar
|
||||||
module.ConstGeneric = ConstGeneric
|
module.ConstGeneric = ConstGeneric
|
||||||
module.Generic = Generic
|
module.Generic = Generic
|
||||||
|
@ -125,11 +154,11 @@ def patch(module):
|
||||||
module.round = round_away_zero
|
module.round = round_away_zero
|
||||||
module.round64 = round_away_zero
|
module.round64 = round_away_zero
|
||||||
module.np_round = np.round
|
module.np_round = np.round
|
||||||
module.floor = math.floor
|
module.floor = _floor
|
||||||
module.floor64 = math.floor
|
module.floor64 = _floor
|
||||||
module.np_floor = np.floor
|
module.np_floor = np.floor
|
||||||
module.ceil = math.ceil
|
module.ceil = _ceil
|
||||||
module.ceil64 = math.ceil
|
module.ceil64 = _ceil
|
||||||
module.np_ceil = np.ceil
|
module.np_ceil = np.ceil
|
||||||
|
|
||||||
# NumPy ndarray functions
|
# NumPy ndarray functions
|
||||||
|
|
|
@ -1,3 +1,11 @@
|
||||||
|
@extern
|
||||||
|
def dbl_nan() -> float:
|
||||||
|
...
|
||||||
|
|
||||||
|
@extern
|
||||||
|
def dbl_inf() -> float:
|
||||||
|
...
|
||||||
|
|
||||||
@extern
|
@extern
|
||||||
def output_bool(x: bool):
|
def output_bool(x: bool):
|
||||||
...
|
...
|
||||||
|
@ -6,6 +14,18 @@ def output_bool(x: bool):
|
||||||
def output_int32(x: int32):
|
def output_int32(x: int32):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@extern
|
||||||
|
def output_int64(x: int64):
|
||||||
|
...
|
||||||
|
|
||||||
|
@extern
|
||||||
|
def output_uint32(x: uint32):
|
||||||
|
...
|
||||||
|
|
||||||
|
@extern
|
||||||
|
def output_uint64(x: uint64):
|
||||||
|
...
|
||||||
|
|
||||||
@extern
|
@extern
|
||||||
def output_float64(x: float):
|
def output_float64(x: float):
|
||||||
...
|
...
|
||||||
|
@ -24,6 +44,21 @@ def output_ndarray_int32_2(n: ndarray[int32, Literal[2]]):
|
||||||
for c in range(len(n[r])):
|
for c in range(len(n[r])):
|
||||||
output_int32(n[r][c])
|
output_int32(n[r][c])
|
||||||
|
|
||||||
|
def output_ndarray_int64_2(n: ndarray[int64, Literal[2]]):
|
||||||
|
for r in range(len(n)):
|
||||||
|
for c in range(len(n[r])):
|
||||||
|
output_int64(n[r][c])
|
||||||
|
|
||||||
|
def output_ndarray_uint32_2(n: ndarray[uint32, Literal[2]]):
|
||||||
|
for r in range(len(n)):
|
||||||
|
for c in range(len(n[r])):
|
||||||
|
output_uint32(n[r][c])
|
||||||
|
|
||||||
|
def output_ndarray_uint64_2(n: ndarray[uint64, Literal[2]]):
|
||||||
|
for r in range(len(n)):
|
||||||
|
for c in range(len(n[r])):
|
||||||
|
output_uint64(n[r][c])
|
||||||
|
|
||||||
def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
|
def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
|
||||||
for i in range(len(n)):
|
for i in range(len(n)):
|
||||||
output_float64(n[i])
|
output_float64(n[i])
|
||||||
|
@ -649,6 +684,586 @@ def test_ndarray_ge_broadcast_rhs_scalar():
|
||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
output_ndarray_bool_2(y)
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_int32():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = int32(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_int32_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_int64():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = int64(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_int64_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_uint32():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = uint32(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_uint32_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_uint64():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = uint64(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_uint64_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_float():
|
||||||
|
x = np_full([2, 2], 1)
|
||||||
|
y = float(x)
|
||||||
|
|
||||||
|
output_ndarray_int32_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_bool():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = bool(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_round():
|
||||||
|
x = np_identity(2)
|
||||||
|
xf32 = round(x)
|
||||||
|
xf64 = round64(x)
|
||||||
|
xff = np_round(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_int32_2(xf32)
|
||||||
|
output_ndarray_int64_2(xf64)
|
||||||
|
output_ndarray_float_2(xff)
|
||||||
|
|
||||||
|
def test_ndarray_floor():
|
||||||
|
x = np_identity(2)
|
||||||
|
xf32 = floor(x)
|
||||||
|
xf64 = floor64(x)
|
||||||
|
xff = np_floor(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_int32_2(xf32)
|
||||||
|
output_ndarray_int64_2(xf64)
|
||||||
|
output_ndarray_float_2(xff)
|
||||||
|
|
||||||
|
def test_ndarray_ceil():
|
||||||
|
x = np_identity(2)
|
||||||
|
xf32 = ceil(x)
|
||||||
|
xf64 = ceil64(x)
|
||||||
|
xff = np_ceil(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_int32_2(xf32)
|
||||||
|
output_ndarray_int64_2(xf64)
|
||||||
|
output_ndarray_float_2(xff)
|
||||||
|
|
||||||
|
def test_ndarray_abs():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = abs(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_isnan():
|
||||||
|
x = np_identity(2)
|
||||||
|
x_isnan = np_isnan(x)
|
||||||
|
y = np_full([2, 2], dbl_nan())
|
||||||
|
y_isnan = np_isnan(y)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(x_isnan)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
output_ndarray_bool_2(y_isnan)
|
||||||
|
|
||||||
|
def test_ndarray_isinf():
|
||||||
|
x = np_identity(2)
|
||||||
|
x_isinf = np_isinf(x)
|
||||||
|
y = np_full([2, 2], dbl_inf())
|
||||||
|
y_isinf = np_isinf(y)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(x_isinf)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
output_ndarray_bool_2(y_isinf)
|
||||||
|
|
||||||
|
def test_ndarray_sin():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_sin(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_cos():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_cos(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_exp():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_exp(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_exp2():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_exp2(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_log():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_log(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_log10():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_log10(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_log2():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_log2(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_fabs():
|
||||||
|
x = -np_identity(2)
|
||||||
|
y = np_fabs(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_sqrt():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_sqrt(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_rint():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_rint(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_tan():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_tan(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_arcsin():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_arcsin(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_arccos():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_arccos(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_arctan():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_arctan(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_sinh():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_sinh(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_cosh():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_cosh(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_tanh():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_tanh(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_arcsinh():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_arcsinh(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_arccosh():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_arccosh(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_arctanh():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_arctanh(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_expm1():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_expm1(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_cbrt():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = np_cbrt(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_erf():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = sp_spec_erf(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_erfc():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = sp_spec_erfc(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_gamma():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = sp_spec_gamma(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_gammaln():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = sp_spec_gammaln(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_j0():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = sp_spec_j0(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_j1():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = sp_spec_j1(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_arctan2():
|
||||||
|
x = np_identity(2)
|
||||||
|
zeros = np_zeros([2, 2])
|
||||||
|
ones = np_ones([2, 2])
|
||||||
|
atan2_x_zeros = np_arctan2(x, zeros)
|
||||||
|
atan2_x_ones = np_arctan2(x, ones)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(zeros)
|
||||||
|
output_ndarray_float_2(ones)
|
||||||
|
output_ndarray_float_2(atan2_x_zeros)
|
||||||
|
output_ndarray_float_2(atan2_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_arctan2_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
atan2_x_zeros = np_arctan2(x, np_zeros([2]))
|
||||||
|
atan2_x_ones = np_arctan2(x, np_ones([2]))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(atan2_x_zeros)
|
||||||
|
output_ndarray_float_2(atan2_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_arctan2_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
atan2_x_zeros = np_arctan2(0.0, x)
|
||||||
|
atan2_x_ones = np_arctan2(1.0, x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(atan2_x_zeros)
|
||||||
|
output_ndarray_float_2(atan2_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_arctan2_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
atan2_x_zeros = np_arctan2(x, 0.0)
|
||||||
|
atan2_x_ones = np_arctan2(x, 1.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(atan2_x_zeros)
|
||||||
|
output_ndarray_float_2(atan2_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_copysign():
|
||||||
|
x = np_identity(2)
|
||||||
|
ones = np_ones([2, 2])
|
||||||
|
negones = np_full([2, 2], -1.0)
|
||||||
|
copysign_x_ones = np_copysign(x, ones)
|
||||||
|
copysign_x_negones = np_copysign(x, ones)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(ones)
|
||||||
|
output_ndarray_float_2(negones)
|
||||||
|
output_ndarray_float_2(copysign_x_ones)
|
||||||
|
output_ndarray_float_2(copysign_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_copysign_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
copysign_x_ones = np_copysign(x, np_ones([2]))
|
||||||
|
copysign_x_negones = np_copysign(x, np_full([2], -1.0))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(copysign_x_ones)
|
||||||
|
output_ndarray_float_2(copysign_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_copysign_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
copysign_x_ones = np_copysign(1.0, x)
|
||||||
|
copysign_x_negones = np_copysign(-1.0, x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(copysign_x_ones)
|
||||||
|
output_ndarray_float_2(copysign_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_copysign_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
copysign_x_ones = np_copysign(x, 1.0)
|
||||||
|
copysign_x_negones = np_copysign(x, -1.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(copysign_x_ones)
|
||||||
|
output_ndarray_float_2(copysign_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_fmax():
|
||||||
|
x = np_identity(2)
|
||||||
|
ones = np_ones([2, 2])
|
||||||
|
negones = np_full([2, 2], -1.0)
|
||||||
|
fmax_x_ones = np_fmax(x, ones)
|
||||||
|
fmax_x_negones = np_fmax(x, ones)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(ones)
|
||||||
|
output_ndarray_float_2(negones)
|
||||||
|
output_ndarray_float_2(fmax_x_ones)
|
||||||
|
output_ndarray_float_2(fmax_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_fmax_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
fmax_x_ones = np_fmax(x, np_ones([2]))
|
||||||
|
fmax_x_negones = np_fmax(x, np_full([2], -1.0))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(fmax_x_ones)
|
||||||
|
output_ndarray_float_2(fmax_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_fmax_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
fmax_x_ones = np_fmax(1.0, x)
|
||||||
|
fmax_x_negones = np_fmax(-1.0, x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(fmax_x_ones)
|
||||||
|
output_ndarray_float_2(fmax_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_fmax_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
fmax_x_ones = np_fmax(x, 1.0)
|
||||||
|
fmax_x_negones = np_fmax(x, -1.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(fmax_x_ones)
|
||||||
|
output_ndarray_float_2(fmax_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_fmin():
|
||||||
|
x = np_identity(2)
|
||||||
|
ones = np_ones([2, 2])
|
||||||
|
negones = np_full([2, 2], -1.0)
|
||||||
|
fmin_x_ones = np_fmin(x, ones)
|
||||||
|
fmin_x_negones = np_fmin(x, ones)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(ones)
|
||||||
|
output_ndarray_float_2(negones)
|
||||||
|
output_ndarray_float_2(fmin_x_ones)
|
||||||
|
output_ndarray_float_2(fmin_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_fmin_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
fmin_x_ones = np_fmin(x, np_ones([2]))
|
||||||
|
fmin_x_negones = np_fmin(x, np_full([2], -1.0))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(fmin_x_ones)
|
||||||
|
output_ndarray_float_2(fmin_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_fmin_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
fmin_x_ones = np_fmin(1.0, x)
|
||||||
|
fmin_x_negones = np_fmin(-1.0, x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(fmin_x_ones)
|
||||||
|
output_ndarray_float_2(fmin_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_fmin_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
fmin_x_ones = np_fmin(x, 1.0)
|
||||||
|
fmin_x_negones = np_fmin(x, -1.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(fmin_x_ones)
|
||||||
|
output_ndarray_float_2(fmin_x_negones)
|
||||||
|
|
||||||
|
def test_ndarray_ldexp():
|
||||||
|
x = np_identity(2)
|
||||||
|
zeros = np_full([2, 2], 0)
|
||||||
|
ones = np_full([2, 2], 1)
|
||||||
|
ldexp_x_zeros = np_ldexp(x, zeros)
|
||||||
|
ldexp_x_ones = np_ldexp(x, ones)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_int32_2(zeros)
|
||||||
|
output_ndarray_int32_2(ones)
|
||||||
|
output_ndarray_float_2(ldexp_x_zeros)
|
||||||
|
output_ndarray_float_2(ldexp_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_ldexp_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
ldexp_x_zeros = np_ldexp(x, np_full([2], 0))
|
||||||
|
ldexp_x_ones = np_ldexp(x, np_full([2], 1))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(ldexp_x_zeros)
|
||||||
|
output_ndarray_float_2(ldexp_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_ldexp_broadcast_lhs_scalar():
|
||||||
|
x = int32(np_identity(2))
|
||||||
|
ldexp_x_zeros = np_ldexp(0.0, x)
|
||||||
|
ldexp_x_ones = np_ldexp(1.0, x)
|
||||||
|
|
||||||
|
output_ndarray_int32_2(x)
|
||||||
|
output_ndarray_float_2(ldexp_x_zeros)
|
||||||
|
output_ndarray_float_2(ldexp_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_ldexp_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
ldexp_x_zeros = np_ldexp(x, 0)
|
||||||
|
ldexp_x_ones = np_ldexp(x, 1)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(ldexp_x_zeros)
|
||||||
|
output_ndarray_float_2(ldexp_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_hypot():
|
||||||
|
x = np_identity(2)
|
||||||
|
zeros = np_zeros([2, 2])
|
||||||
|
ones = np_ones([2, 2])
|
||||||
|
hypot_x_zeros = np_hypot(x, zeros)
|
||||||
|
hypot_x_ones = np_hypot(x, ones)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(zeros)
|
||||||
|
output_ndarray_float_2(ones)
|
||||||
|
output_ndarray_float_2(hypot_x_zeros)
|
||||||
|
output_ndarray_float_2(hypot_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_hypot_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
hypot_x_zeros = np_hypot(x, np_zeros([2]))
|
||||||
|
hypot_x_ones = np_hypot(x, np_ones([2]))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(hypot_x_zeros)
|
||||||
|
output_ndarray_float_2(hypot_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_hypot_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
hypot_x_zeros = np_hypot(0.0, x)
|
||||||
|
hypot_x_ones = np_hypot(1.0, x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(hypot_x_zeros)
|
||||||
|
output_ndarray_float_2(hypot_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_hypot_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
hypot_x_zeros = np_hypot(x, 0.0)
|
||||||
|
hypot_x_ones = np_hypot(x, 1.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(hypot_x_zeros)
|
||||||
|
output_ndarray_float_2(hypot_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_nextafter():
|
||||||
|
x = np_identity(2)
|
||||||
|
zeros = np_zeros([2, 2])
|
||||||
|
ones = np_ones([2, 2])
|
||||||
|
nextafter_x_zeros = np_nextafter(x, zeros)
|
||||||
|
nextafter_x_ones = np_nextafter(x, ones)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(zeros)
|
||||||
|
output_ndarray_float_2(ones)
|
||||||
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_nextafter_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
nextafter_x_zeros = np_nextafter(x, np_zeros([2]))
|
||||||
|
nextafter_x_ones = np_nextafter(x, np_ones([2]))
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_nextafter_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
nextafter_x_zeros = np_nextafter(0.0, x)
|
||||||
|
nextafter_x_ones = np_nextafter(1.0, x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
|
def test_ndarray_nextafter_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
nextafter_x_zeros = np_nextafter(x, 0.0)
|
||||||
|
nextafter_x_ones = np_nextafter(x, 1.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
test_ndarray_ctor()
|
test_ndarray_ctor()
|
||||||
test_ndarray_empty()
|
test_ndarray_empty()
|
||||||
|
@ -739,4 +1354,76 @@ def run() -> int32:
|
||||||
test_ndarray_ge_broadcast_lhs_scalar()
|
test_ndarray_ge_broadcast_lhs_scalar()
|
||||||
test_ndarray_ge_broadcast_rhs_scalar()
|
test_ndarray_ge_broadcast_rhs_scalar()
|
||||||
|
|
||||||
|
test_ndarray_int32()
|
||||||
|
test_ndarray_int64()
|
||||||
|
test_ndarray_uint32()
|
||||||
|
test_ndarray_uint64()
|
||||||
|
test_ndarray_float()
|
||||||
|
test_ndarray_bool()
|
||||||
|
|
||||||
|
test_ndarray_round()
|
||||||
|
test_ndarray_floor()
|
||||||
|
test_ndarray_abs()
|
||||||
|
test_ndarray_isnan()
|
||||||
|
test_ndarray_isinf()
|
||||||
|
|
||||||
|
test_ndarray_sin()
|
||||||
|
test_ndarray_cos()
|
||||||
|
test_ndarray_exp()
|
||||||
|
test_ndarray_exp2()
|
||||||
|
test_ndarray_log()
|
||||||
|
test_ndarray_log10()
|
||||||
|
test_ndarray_log2()
|
||||||
|
test_ndarray_fabs()
|
||||||
|
test_ndarray_sqrt()
|
||||||
|
test_ndarray_rint()
|
||||||
|
test_ndarray_tan()
|
||||||
|
test_ndarray_arcsin()
|
||||||
|
test_ndarray_arccos()
|
||||||
|
test_ndarray_arctan()
|
||||||
|
test_ndarray_sinh()
|
||||||
|
test_ndarray_cosh()
|
||||||
|
test_ndarray_tanh()
|
||||||
|
test_ndarray_arcsinh()
|
||||||
|
test_ndarray_arccosh()
|
||||||
|
test_ndarray_arctanh()
|
||||||
|
test_ndarray_expm1()
|
||||||
|
test_ndarray_cbrt()
|
||||||
|
|
||||||
|
test_ndarray_erf()
|
||||||
|
test_ndarray_erfc()
|
||||||
|
test_ndarray_gamma()
|
||||||
|
test_ndarray_gammaln()
|
||||||
|
test_ndarray_j0()
|
||||||
|
test_ndarray_j1()
|
||||||
|
|
||||||
|
test_ndarray_arctan2()
|
||||||
|
test_ndarray_arctan2_broadcast()
|
||||||
|
test_ndarray_arctan2_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_arctan2_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_copysign()
|
||||||
|
test_ndarray_copysign_broadcast()
|
||||||
|
test_ndarray_copysign_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_copysign_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_fmax()
|
||||||
|
test_ndarray_fmax_broadcast()
|
||||||
|
test_ndarray_fmax_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_fmax_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_fmin()
|
||||||
|
test_ndarray_fmin_broadcast()
|
||||||
|
test_ndarray_fmin_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_fmin_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_ldexp()
|
||||||
|
test_ndarray_ldexp_broadcast()
|
||||||
|
test_ndarray_ldexp_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_ldexp_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_hypot()
|
||||||
|
test_ndarray_hypot_broadcast()
|
||||||
|
test_ndarray_hypot_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_hypot_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_nextafter()
|
||||||
|
test_ndarray_nextafter_broadcast()
|
||||||
|
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
Loading…
Reference in New Issue