Allow numpy functions to accept scalars or ndarrays #400

Merged
sb10q merged 1 commits from enhance/issue-149-ndarray/numpy-func-take-ndarray into master 2024-08-17 17:37:21 +08:00
11 changed files with 3521 additions and 1015 deletions

File diff suppressed because it is too large Load Diff

View File

@ -451,8 +451,6 @@ fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>(
} }
}; };
debug_assert_eq!(lhs_elem.get_type(), rhs_elem.get_type());
value_fn(generator, ctx, (lhs_elem, rhs_elem)) value_fn(generator, ctx, (lhs_elem, rhs_elem))
}, },
)?; )?;

File diff suppressed because it is too large Load Diff

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [127]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [222]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar116]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar116\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar211]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar211\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [129]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [224]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [134]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [229]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar115, typevar116]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar115\", \"typevar116\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar210, typevar211]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar210\", \"typevar211\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [135]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [230]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [143]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [238]\n}\n",
] ]

View File

@ -14,17 +14,7 @@ use crate::{
}, },
}; };
use itertools::{Itertools, izip}; use itertools::{Itertools, izip};
use nac3parser::ast::{ use nac3parser::ast::{self, fold::{self, Fold}, Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef};
self,
fold::{self, Fold},
Arguments,
Comprehension,
ExprContext,
ExprKind,
Located,
Location,
StrRef
};
#[cfg(test)] #[cfg(test)]
mod test; mod test;
@ -860,17 +850,194 @@ impl<'a> Inferencer<'a> {
}, },
})) }))
} }
// int64 is special because its argument can be a constant larger than int32
if id == &"int64".into() && args.len() == 1 { if [
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = "int32",
&args[0].node "float",
{ "bool",
let custom = Some(self.primitives.int64); "np_isnan",
let v: Result<i64, _> = (*val).try_into(); "np_isinf",
return if v.is_ok() { ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
let target_ty = if id == &"int32".into() {
self.primitives.int32
} else if id == &"float".into() {
self.primitives.float
} else if id == &"bool".into() || id == &"np_isnan".into() || id == &"np_isinf".into() {
self.primitives.bool
} else { unreachable!() };
let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
} else {
target_ty
};
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "n".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0],
keywords: vec![],
},
}))
}
if [
"np_arctan2",
"np_copysign",
"np_fmax",
"np_fmin",
"np_ldexp",
"np_hypot",
"np_nextafter",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 {
let target_ty = self.primitives.float;
let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let arg1 = self.fold_expr(args.remove(0))?;
let arg1_ty = arg1.custom.unwrap();
let arg0_dtype = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
unpack_ndarray_var_tys(self.unifier, arg0_ty).0
} else {
arg0_ty
};
let arg1_dtype = if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
unpack_ndarray_var_tys(self.unifier, arg1_ty).0
} else {
arg1_ty
};
let expected_arg1_dtype = if id == &"np_ldexp".into() {
self.primitives.int32
} else {
arg0_dtype
};
if !self.unifier.unioned(arg1_dtype, expected_arg1_dtype) {
return report_error(
format!(
"Expected {} for second argument of {id}, got {}",
self.unifier.stringify(expected_arg1_dtype),
self.unifier.stringify(arg1_dtype),
).as_str(),
arg0.location,
)
}
let ret = if [
&arg0_ty,
&arg1_ty,
].into_iter().any(|arg_ty| arg_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) {
// typeof_ndarray_broadcast requires both dtypes to be the same, but ldexp accepts
// (float, int32), so convert it to align with the dtype of the first arg
let arg1_ty = if id == &"np_ldexp".into() {
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims))
} else {
target_ty
}
} else {
arg1_ty
};
match typeof_ndarray_broadcast(self.unifier, self.primitives, arg0_ty, arg1_ty) {
Ok(broadcasted_ty) => broadcasted_ty,
Err(err) => return report_error(err.as_str(), location),
}
} else {
target_ty
};
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "x1".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
FuncArg {
name: "x2".into(),
ty: arg1.custom.unwrap(),
default_value: None,
},
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0, arg1],
keywords: vec![],
},
}))
}
// int64, uint32 and uint64 are special because their argument can be a constant outside the
// range of int32s
if [
"int64",
"uint32",
"uint64",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
let target_ty = if id == &"int64".into() {
self.primitives.int64
} else if id == &"uint32".into() {
self.primitives.uint32
} else if id == &"uint64".into() {
self.primitives.uint64
} else { unreachable!() };
// Handle constants first to ensure that their types are not defaulted to int32, which
// causes an "Integer out of bound" error
if let ExprKind::Constant {
value: ast::Constant::Int(val),
kind
} = &args[0].node {
let conv_is_ok = if self.unifier.unioned(target_ty, self.primitives.int64) {
i64::try_from(*val).is_ok()
} else if self.unifier.unioned(target_ty, self.primitives.uint32) {
u32::try_from(*val).is_ok()
} else if self.unifier.unioned(target_ty, self.primitives.uint64) {
u64::try_from(*val).is_ok()
} else { unreachable!() };
return if conv_is_ok {
Ok(Some(Located { Ok(Some(Located {
location: args[0].location, location: args[0].location,
custom, custom: Some(target_ty),
node: ExprKind::Constant { node: ExprKind::Constant {
value: ast::Constant::Int(*val), value: ast::Constant::Int(*val),
kind: kind.clone(), kind: kind.clone(),
@ -880,46 +1047,43 @@ impl<'a> Inferencer<'a> {
report_error("Integer out of bound", args[0].location) report_error("Integer out of bound", args[0].location)
} }
} }
}
if id == &"uint32".into() && args.len() == 1 { let arg0 = self.fold_expr(args.remove(0))?;
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = let arg0_ty = arg0.custom.unwrap();
&args[0].node
{ let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let custom = Some(self.primitives.uint32); let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
let v: Result<u32, _> = (*val).try_into();
return if v.is_ok() { make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
Ok(Some(Located { } else {
location: args[0].location, target_ty
custom, };
node: ExprKind::Constant {
value: ast::Constant::Int(*val), let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
kind: kind.clone(), args: vec![
FuncArg {
name: "n".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
],
ret,
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0],
keywords: vec![],
}, },
})) }))
} else {
report_error("Integer out of bound", args[0].location)
}
}
}
if id == &"uint64".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node
{
let custom = Some(self.primitives.uint64);
let v: Result<u64, _> = (*val).try_into();
return if v.is_ok() {
Ok(Some(Located {
location: args[0].location,
custom,
node: ExprKind::Constant {
value: ast::Constant::Int(*val),
kind: kind.clone(),
},
}))
} else {
report_error("Integer out of bound", args[0].location)
}
}
} }
// 1-argument ndarray n-dimensional creation functions // 1-argument ndarray n-dimensional creation functions

View File

@ -58,12 +58,39 @@ class _NDArrayDummy(Generic[T, N]):
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic # https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]] NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
def _bool(x):
if isinstance(x, np.ndarray):
return np.bool_(x)
else:
return bool(x)
def _float(x):
if isinstance(x, np.ndarray):
return np.float_(x)
else:
return float(x)
def round_away_zero(x): def round_away_zero(x):
if isinstance(x, np.ndarray):
return np.vectorize(round_away_zero)(x)
else:
if x >= 0.0: if x >= 0.0:
return math.floor(x + 0.5) return math.floor(x + 0.5)
else: else:
return math.ceil(x - 0.5) return math.ceil(x - 0.5)
def _floor(x):
if isinstance(x, np.ndarray):
return np.vectorize(_floor)(x)
else:
return math.floor(x)
def _ceil(x):
if isinstance(x, np.ndarray):
return np.vectorize(_ceil)(x)
else:
return math.ceil(x)
def patch(module): def patch(module):
def dbl_nan(): def dbl_nan():
return np.nan return np.nan
@ -112,6 +139,8 @@ def patch(module):
module.int64 = int64 module.int64 = int64
module.uint32 = uint32 module.uint32 = uint32
module.uint64 = uint64 module.uint64 = uint64
module.bool = _bool
module.float = _float
module.TypeVar = TypeVar module.TypeVar = TypeVar
module.ConstGeneric = ConstGeneric module.ConstGeneric = ConstGeneric
module.Generic = Generic module.Generic = Generic
@ -125,11 +154,11 @@ def patch(module):
module.round = round_away_zero module.round = round_away_zero
module.round64 = round_away_zero module.round64 = round_away_zero
module.np_round = np.round module.np_round = np.round
module.floor = math.floor module.floor = _floor
module.floor64 = math.floor module.floor64 = _floor
module.np_floor = np.floor module.np_floor = np.floor
module.ceil = math.ceil module.ceil = _ceil
module.ceil64 = math.ceil module.ceil64 = _ceil
module.np_ceil = np.ceil module.np_ceil = np.ceil
# NumPy ndarray functions # NumPy ndarray functions

View File

@ -1,3 +1,11 @@
@extern
def dbl_nan() -> float:
...
@extern
def dbl_inf() -> float:
...
@extern @extern
def output_bool(x: bool): def output_bool(x: bool):
... ...
@ -6,6 +14,18 @@ def output_bool(x: bool):
def output_int32(x: int32): def output_int32(x: int32):
... ...
@extern
def output_int64(x: int64):
...
@extern
def output_uint32(x: uint32):
...
@extern
def output_uint64(x: uint64):
...
@extern @extern
def output_float64(x: float): def output_float64(x: float):
... ...
@ -24,6 +44,21 @@ def output_ndarray_int32_2(n: ndarray[int32, Literal[2]]):
for c in range(len(n[r])): for c in range(len(n[r])):
output_int32(n[r][c]) output_int32(n[r][c])
def output_ndarray_int64_2(n: ndarray[int64, Literal[2]]):
for r in range(len(n)):
for c in range(len(n[r])):
output_int64(n[r][c])
def output_ndarray_uint32_2(n: ndarray[uint32, Literal[2]]):
for r in range(len(n)):
for c in range(len(n[r])):
output_uint32(n[r][c])
def output_ndarray_uint64_2(n: ndarray[uint64, Literal[2]]):
for r in range(len(n)):
for c in range(len(n[r])):
output_uint64(n[r][c])
def output_ndarray_float_1(n: ndarray[float, Literal[1]]): def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
for i in range(len(n)): for i in range(len(n)):
output_float64(n[i]) output_float64(n[i])
@ -649,6 +684,586 @@ def test_ndarray_ge_broadcast_rhs_scalar():
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_ndarray_bool_2(y) output_ndarray_bool_2(y)
def test_ndarray_int32():
x = np_identity(2)
y = int32(x)
output_ndarray_float_2(x)
output_ndarray_int32_2(y)
def test_ndarray_int64():
x = np_identity(2)
y = int64(x)
output_ndarray_float_2(x)
output_ndarray_int64_2(y)
def test_ndarray_uint32():
x = np_identity(2)
y = uint32(x)
output_ndarray_float_2(x)
output_ndarray_uint32_2(y)
def test_ndarray_uint64():
x = np_identity(2)
y = uint64(x)
output_ndarray_float_2(x)
output_ndarray_uint64_2(y)
def test_ndarray_float():
x = np_full([2, 2], 1)
y = float(x)
output_ndarray_int32_2(x)
output_ndarray_float_2(y)
def test_ndarray_bool():
x = np_identity(2)
y = bool(x)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_round():
x = np_identity(2)
xf32 = round(x)
xf64 = round64(x)
xff = np_round(x)
output_ndarray_float_2(x)
output_ndarray_int32_2(xf32)
output_ndarray_int64_2(xf64)
output_ndarray_float_2(xff)
def test_ndarray_floor():
x = np_identity(2)
xf32 = floor(x)
xf64 = floor64(x)
xff = np_floor(x)
output_ndarray_float_2(x)
output_ndarray_int32_2(xf32)
output_ndarray_int64_2(xf64)
output_ndarray_float_2(xff)
def test_ndarray_ceil():
x = np_identity(2)
xf32 = ceil(x)
xf64 = ceil64(x)
xff = np_ceil(x)
output_ndarray_float_2(x)
output_ndarray_int32_2(xf32)
output_ndarray_int64_2(xf64)
output_ndarray_float_2(xff)
def test_ndarray_abs():
x = np_identity(2)
y = abs(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_isnan():
x = np_identity(2)
x_isnan = np_isnan(x)
y = np_full([2, 2], dbl_nan())
y_isnan = np_isnan(y)
output_ndarray_float_2(x)
output_ndarray_bool_2(x_isnan)
output_ndarray_float_2(y)
output_ndarray_bool_2(y_isnan)
def test_ndarray_isinf():
x = np_identity(2)
x_isinf = np_isinf(x)
y = np_full([2, 2], dbl_inf())
y_isinf = np_isinf(y)
output_ndarray_float_2(x)
output_ndarray_bool_2(x_isinf)
output_ndarray_float_2(y)
output_ndarray_bool_2(y_isinf)
def test_ndarray_sin():
x = np_identity(2)
y = np_sin(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_cos():
x = np_identity(2)
y = np_cos(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_exp():
x = np_identity(2)
y = np_exp(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_exp2():
x = np_identity(2)
y = np_exp2(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_log():
x = np_identity(2)
y = np_log(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_log10():
x = np_identity(2)
y = np_log10(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_log2():
x = np_identity(2)
y = np_log2(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_fabs():
x = -np_identity(2)
y = np_fabs(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_sqrt():
x = np_identity(2)
y = np_sqrt(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_rint():
x = np_identity(2)
y = np_rint(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_tan():
x = np_identity(2)
y = np_tan(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_arcsin():
x = np_identity(2)
y = np_arcsin(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_arccos():
x = np_identity(2)
y = np_arccos(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_arctan():
x = np_identity(2)
y = np_arctan(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_sinh():
x = np_identity(2)
y = np_sinh(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_cosh():
x = np_identity(2)
y = np_cosh(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_tanh():
x = np_identity(2)
y = np_tanh(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_arcsinh():
x = np_identity(2)
y = np_arcsinh(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_arccosh():
x = np_identity(2)
y = np_arccosh(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_arctanh():
x = np_identity(2)
y = np_arctanh(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_expm1():
x = np_identity(2)
y = np_expm1(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_cbrt():
x = np_identity(2)
y = np_cbrt(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_erf():
x = np_identity(2)
y = sp_spec_erf(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_erfc():
x = np_identity(2)
y = sp_spec_erfc(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_gamma():
x = np_identity(2)
y = sp_spec_gamma(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_gammaln():
x = np_identity(2)
y = sp_spec_gammaln(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_j0():
x = np_identity(2)
y = sp_spec_j0(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_j1():
x = np_identity(2)
y = sp_spec_j1(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_arctan2():
x = np_identity(2)
zeros = np_zeros([2, 2])
ones = np_ones([2, 2])
atan2_x_zeros = np_arctan2(x, zeros)
atan2_x_ones = np_arctan2(x, ones)
output_ndarray_float_2(x)
output_ndarray_float_2(zeros)
output_ndarray_float_2(ones)
output_ndarray_float_2(atan2_x_zeros)
output_ndarray_float_2(atan2_x_ones)
def test_ndarray_arctan2_broadcast():
x = np_identity(2)
atan2_x_zeros = np_arctan2(x, np_zeros([2]))
atan2_x_ones = np_arctan2(x, np_ones([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(atan2_x_zeros)
output_ndarray_float_2(atan2_x_ones)
def test_ndarray_arctan2_broadcast_lhs_scalar():
x = np_identity(2)
atan2_x_zeros = np_arctan2(0.0, x)
atan2_x_ones = np_arctan2(1.0, x)
output_ndarray_float_2(x)
output_ndarray_float_2(atan2_x_zeros)
output_ndarray_float_2(atan2_x_ones)
def test_ndarray_arctan2_broadcast_rhs_scalar():
x = np_identity(2)
atan2_x_zeros = np_arctan2(x, 0.0)
atan2_x_ones = np_arctan2(x, 1.0)
output_ndarray_float_2(x)
output_ndarray_float_2(atan2_x_zeros)
output_ndarray_float_2(atan2_x_ones)
def test_ndarray_copysign():
x = np_identity(2)
ones = np_ones([2, 2])
negones = np_full([2, 2], -1.0)
copysign_x_ones = np_copysign(x, ones)
copysign_x_negones = np_copysign(x, ones)
output_ndarray_float_2(x)
output_ndarray_float_2(ones)
output_ndarray_float_2(negones)
output_ndarray_float_2(copysign_x_ones)
output_ndarray_float_2(copysign_x_negones)
def test_ndarray_copysign_broadcast():
x = np_identity(2)
copysign_x_ones = np_copysign(x, np_ones([2]))
copysign_x_negones = np_copysign(x, np_full([2], -1.0))
output_ndarray_float_2(x)
output_ndarray_float_2(copysign_x_ones)
output_ndarray_float_2(copysign_x_negones)
def test_ndarray_copysign_broadcast_lhs_scalar():
x = np_identity(2)
copysign_x_ones = np_copysign(1.0, x)
copysign_x_negones = np_copysign(-1.0, x)
output_ndarray_float_2(x)
output_ndarray_float_2(copysign_x_ones)
output_ndarray_float_2(copysign_x_negones)
def test_ndarray_copysign_broadcast_rhs_scalar():
x = np_identity(2)
copysign_x_ones = np_copysign(x, 1.0)
copysign_x_negones = np_copysign(x, -1.0)
output_ndarray_float_2(x)
output_ndarray_float_2(copysign_x_ones)
output_ndarray_float_2(copysign_x_negones)
def test_ndarray_fmax():
x = np_identity(2)
ones = np_ones([2, 2])
negones = np_full([2, 2], -1.0)
fmax_x_ones = np_fmax(x, ones)
fmax_x_negones = np_fmax(x, ones)
output_ndarray_float_2(x)
output_ndarray_float_2(ones)
output_ndarray_float_2(negones)
output_ndarray_float_2(fmax_x_ones)
output_ndarray_float_2(fmax_x_negones)
def test_ndarray_fmax_broadcast():
x = np_identity(2)
fmax_x_ones = np_fmax(x, np_ones([2]))
fmax_x_negones = np_fmax(x, np_full([2], -1.0))
output_ndarray_float_2(x)
output_ndarray_float_2(fmax_x_ones)
output_ndarray_float_2(fmax_x_negones)
def test_ndarray_fmax_broadcast_lhs_scalar():
x = np_identity(2)
fmax_x_ones = np_fmax(1.0, x)
fmax_x_negones = np_fmax(-1.0, x)
output_ndarray_float_2(x)
output_ndarray_float_2(fmax_x_ones)
output_ndarray_float_2(fmax_x_negones)
def test_ndarray_fmax_broadcast_rhs_scalar():
x = np_identity(2)
fmax_x_ones = np_fmax(x, 1.0)
fmax_x_negones = np_fmax(x, -1.0)
output_ndarray_float_2(x)
output_ndarray_float_2(fmax_x_ones)
output_ndarray_float_2(fmax_x_negones)
def test_ndarray_fmin():
x = np_identity(2)
ones = np_ones([2, 2])
negones = np_full([2, 2], -1.0)
fmin_x_ones = np_fmin(x, ones)
fmin_x_negones = np_fmin(x, ones)
output_ndarray_float_2(x)
output_ndarray_float_2(ones)
output_ndarray_float_2(negones)
output_ndarray_float_2(fmin_x_ones)
output_ndarray_float_2(fmin_x_negones)
def test_ndarray_fmin_broadcast():
x = np_identity(2)
fmin_x_ones = np_fmin(x, np_ones([2]))
fmin_x_negones = np_fmin(x, np_full([2], -1.0))
output_ndarray_float_2(x)
output_ndarray_float_2(fmin_x_ones)
output_ndarray_float_2(fmin_x_negones)
def test_ndarray_fmin_broadcast_lhs_scalar():
x = np_identity(2)
fmin_x_ones = np_fmin(1.0, x)
fmin_x_negones = np_fmin(-1.0, x)
output_ndarray_float_2(x)
output_ndarray_float_2(fmin_x_ones)
output_ndarray_float_2(fmin_x_negones)
def test_ndarray_fmin_broadcast_rhs_scalar():
x = np_identity(2)
fmin_x_ones = np_fmin(x, 1.0)
fmin_x_negones = np_fmin(x, -1.0)
output_ndarray_float_2(x)
output_ndarray_float_2(fmin_x_ones)
output_ndarray_float_2(fmin_x_negones)
def test_ndarray_ldexp():
x = np_identity(2)
zeros = np_full([2, 2], 0)
ones = np_full([2, 2], 1)
ldexp_x_zeros = np_ldexp(x, zeros)
ldexp_x_ones = np_ldexp(x, ones)
output_ndarray_float_2(x)
output_ndarray_int32_2(zeros)
output_ndarray_int32_2(ones)
output_ndarray_float_2(ldexp_x_zeros)
output_ndarray_float_2(ldexp_x_ones)
def test_ndarray_ldexp_broadcast():
x = np_identity(2)
ldexp_x_zeros = np_ldexp(x, np_full([2], 0))
ldexp_x_ones = np_ldexp(x, np_full([2], 1))
output_ndarray_float_2(x)
output_ndarray_float_2(ldexp_x_zeros)
output_ndarray_float_2(ldexp_x_ones)
def test_ndarray_ldexp_broadcast_lhs_scalar():
x = int32(np_identity(2))
ldexp_x_zeros = np_ldexp(0.0, x)
ldexp_x_ones = np_ldexp(1.0, x)
output_ndarray_int32_2(x)
output_ndarray_float_2(ldexp_x_zeros)
output_ndarray_float_2(ldexp_x_ones)
def test_ndarray_ldexp_broadcast_rhs_scalar():
x = np_identity(2)
ldexp_x_zeros = np_ldexp(x, 0)
ldexp_x_ones = np_ldexp(x, 1)
output_ndarray_float_2(x)
output_ndarray_float_2(ldexp_x_zeros)
output_ndarray_float_2(ldexp_x_ones)
def test_ndarray_hypot():
x = np_identity(2)
zeros = np_zeros([2, 2])
ones = np_ones([2, 2])
hypot_x_zeros = np_hypot(x, zeros)
hypot_x_ones = np_hypot(x, ones)
output_ndarray_float_2(x)
output_ndarray_float_2(zeros)
output_ndarray_float_2(ones)
output_ndarray_float_2(hypot_x_zeros)
output_ndarray_float_2(hypot_x_ones)
def test_ndarray_hypot_broadcast():
x = np_identity(2)
hypot_x_zeros = np_hypot(x, np_zeros([2]))
hypot_x_ones = np_hypot(x, np_ones([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(hypot_x_zeros)
output_ndarray_float_2(hypot_x_ones)
def test_ndarray_hypot_broadcast_lhs_scalar():
x = np_identity(2)
hypot_x_zeros = np_hypot(0.0, x)
hypot_x_ones = np_hypot(1.0, x)
output_ndarray_float_2(x)
output_ndarray_float_2(hypot_x_zeros)
output_ndarray_float_2(hypot_x_ones)
def test_ndarray_hypot_broadcast_rhs_scalar():
x = np_identity(2)
hypot_x_zeros = np_hypot(x, 0.0)
hypot_x_ones = np_hypot(x, 1.0)
output_ndarray_float_2(x)
output_ndarray_float_2(hypot_x_zeros)
output_ndarray_float_2(hypot_x_ones)
def test_ndarray_nextafter():
x = np_identity(2)
zeros = np_zeros([2, 2])
ones = np_ones([2, 2])
nextafter_x_zeros = np_nextafter(x, zeros)
nextafter_x_ones = np_nextafter(x, ones)
output_ndarray_float_2(x)
output_ndarray_float_2(zeros)
output_ndarray_float_2(ones)
output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_nextafter_broadcast():
x = np_identity(2)
nextafter_x_zeros = np_nextafter(x, np_zeros([2]))
nextafter_x_ones = np_nextafter(x, np_ones([2]))
output_ndarray_float_2(x)
output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_nextafter_broadcast_lhs_scalar():
x = np_identity(2)
nextafter_x_zeros = np_nextafter(0.0, x)
nextafter_x_ones = np_nextafter(1.0, x)
output_ndarray_float_2(x)
output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_nextafter_broadcast_rhs_scalar():
x = np_identity(2)
nextafter_x_zeros = np_nextafter(x, 0.0)
nextafter_x_ones = np_nextafter(x, 1.0)
output_ndarray_float_2(x)
output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones)
def run() -> int32: def run() -> int32:
test_ndarray_ctor() test_ndarray_ctor()
test_ndarray_empty() test_ndarray_empty()
@ -739,4 +1354,76 @@ def run() -> int32:
test_ndarray_ge_broadcast_lhs_scalar() test_ndarray_ge_broadcast_lhs_scalar()
test_ndarray_ge_broadcast_rhs_scalar() test_ndarray_ge_broadcast_rhs_scalar()
test_ndarray_int32()
test_ndarray_int64()
test_ndarray_uint32()
test_ndarray_uint64()
test_ndarray_float()
test_ndarray_bool()
test_ndarray_round()
test_ndarray_floor()
test_ndarray_abs()
test_ndarray_isnan()
test_ndarray_isinf()
test_ndarray_sin()
test_ndarray_cos()
test_ndarray_exp()
test_ndarray_exp2()
test_ndarray_log()
test_ndarray_log10()
test_ndarray_log2()
test_ndarray_fabs()
test_ndarray_sqrt()
test_ndarray_rint()
test_ndarray_tan()
test_ndarray_arcsin()
test_ndarray_arccos()
test_ndarray_arctan()
test_ndarray_sinh()
test_ndarray_cosh()
test_ndarray_tanh()
test_ndarray_arcsinh()
test_ndarray_arccosh()
test_ndarray_arctanh()
test_ndarray_expm1()
test_ndarray_cbrt()
test_ndarray_erf()
test_ndarray_erfc()
test_ndarray_gamma()
test_ndarray_gammaln()
test_ndarray_j0()
test_ndarray_j1()
test_ndarray_arctan2()
test_ndarray_arctan2_broadcast()
test_ndarray_arctan2_broadcast_lhs_scalar()
test_ndarray_arctan2_broadcast_rhs_scalar()
test_ndarray_copysign()
test_ndarray_copysign_broadcast()
test_ndarray_copysign_broadcast_lhs_scalar()
test_ndarray_copysign_broadcast_rhs_scalar()
test_ndarray_fmax()
test_ndarray_fmax_broadcast()
test_ndarray_fmax_broadcast_lhs_scalar()
test_ndarray_fmax_broadcast_rhs_scalar()
test_ndarray_fmin()
test_ndarray_fmin_broadcast()
test_ndarray_fmin_broadcast_lhs_scalar()
test_ndarray_fmin_broadcast_rhs_scalar()
test_ndarray_ldexp()
test_ndarray_ldexp_broadcast()
test_ndarray_ldexp_broadcast_lhs_scalar()
test_ndarray_ldexp_broadcast_rhs_scalar()
test_ndarray_hypot()
test_ndarray_hypot_broadcast()
test_ndarray_hypot_broadcast_lhs_scalar()
test_ndarray_hypot_broadcast_rhs_scalar()
test_ndarray_nextafter()
test_ndarray_nextafter_broadcast()
test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar()
return 0 return 0