forked from M-Labs/nac3
core: Implement elementwise comparison operators
This commit is contained in:
parent
727a1886b3
commit
a920fe0501
|
@ -1390,12 +1390,92 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
|
|||
/// Generates LLVM IR for a comparison operator expression using the [`Type`] and
|
||||
/// [LLVM value][`BasicValueEnum`] of the operands.
|
||||
pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
_generator: &mut G,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
left: (Option<Type>, BasicValueEnum<'ctx>),
|
||||
ops: &[ast::Cmpop],
|
||||
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
debug_assert_eq!(comparators.len(), ops.len());
|
||||
|
||||
if comparators.len() == 1 {
|
||||
let left_ty = ctx.unifier.get_representative(left.0.unwrap());
|
||||
let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap());
|
||||
|
||||
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (Some(left_ty), lhs) = left else { unreachable!() };
|
||||
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
|
||||
let op = ops[0].clone();
|
||||
|
||||
let is_ndarray1 = left_ty.obj_id(&ctx.unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
let is_ndarray2 = right_ty.obj_id(&ctx.unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
|
||||
return if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
||||
|
||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
let left_val = NDArrayValue::from_ptr_val(
|
||||
lhs.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None
|
||||
);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.bool,
|
||||
None,
|
||||
(left_val.as_ptr_value().into(), false),
|
||||
(rhs, false),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
let val = gen_cmpop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(Some(ndarray_dtype1), lhs),
|
||||
&[op.clone()],
|
||||
&[(Some(ndarray_dtype2), rhs)],
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
|
||||
|
||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Some(res.as_ptr_value().into()))
|
||||
} else {
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
||||
&mut ctx.unifier,
|
||||
if is_ndarray1 { left_ty } else { right_ty },
|
||||
);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.bool,
|
||||
None,
|
||||
(lhs, !is_ndarray1),
|
||||
(rhs, !is_ndarray2),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
let val = gen_cmpop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(Some(ndarray_dtype), lhs),
|
||||
&[op.clone()],
|
||||
&[(Some(ndarray_dtype), rhs)],
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
|
||||
|
||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Some(res.as_ptr_value().into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
|
||||
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
|
||||
let (left_ty, lhs) = lhs;
|
||||
|
@ -1451,7 +1531,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
|
||||
let lhs = lhs.into_float_value();
|
||||
let rhs = rhs.into_float_value();
|
||||
|
||||
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
|
||||
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
|
||||
|
@ -1465,6 +1545,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
} else {
|
||||
unimplemented!()
|
||||
};
|
||||
|
||||
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
||||
})?;
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
|||
[
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [127]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [156]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar116]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar116\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar145]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar145\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [129]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [134]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [158]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [163]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar115, typevar116]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar115\", \"typevar116\"]\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar144, typevar145]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar144\", \"typevar145\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [135]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [164]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [143]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [172]\n}\n",
|
||||
]
|
||||
|
|
|
@ -483,6 +483,33 @@ pub fn typeof_unaryop(
|
|||
})
|
||||
}
|
||||
|
||||
/// Returns the return type given a comparison operator and its primitive operands.
|
||||
pub fn typeof_cmpop(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
_op: &Cmpop,
|
||||
lhs: Type,
|
||||
rhs: Type,
|
||||
) -> Result<Option<Type>, String> {
|
||||
let is_left_ndarray = lhs
|
||||
.obj_id(unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
let is_right_ndarray = rhs
|
||||
.obj_id(unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
|
||||
Ok(Some(if is_left_ndarray || is_right_ndarray {
|
||||
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
|
||||
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
|
||||
|
||||
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
|
||||
} else if unifier.unioned(lhs, rhs) {
|
||||
primitives.bool
|
||||
} else {
|
||||
return Ok(None)
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
||||
let PrimitiveStore {
|
||||
int32: int32_t,
|
||||
|
@ -508,8 +535,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_invert(unifier, store, t, Some(t));
|
||||
impl_not(unifier, store, t, Some(bool_t));
|
||||
impl_comparison(unifier, store, t, &[t], Some(bool_t));
|
||||
impl_eq(unifier, store, t, &[t], Some(bool_t));
|
||||
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
}
|
||||
for t in [int32_t, int64_t] {
|
||||
impl_sign(unifier, store, t, Some(t));
|
||||
|
@ -525,12 +552,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_sign(unifier, store, float_t, Some(float_t));
|
||||
impl_not(unifier, store, float_t, Some(bool_t));
|
||||
impl_comparison(unifier, store, float_t, &[float_t], Some(bool_t));
|
||||
impl_eq(unifier, store, float_t, &[float_t], Some(bool_t));
|
||||
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
|
||||
/* bool ======== */
|
||||
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
|
||||
impl_not(unifier, store, bool_t, Some(bool_t));
|
||||
impl_eq(unifier, store, bool_t, &[bool_t], Some(bool_t));
|
||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||
|
||||
/* ndarray ===== */
|
||||
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
||||
|
@ -544,4 +572,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
}
|
||||
|
|
|
@ -1271,22 +1271,45 @@ impl<'a> Inferencer<'a> {
|
|||
ops: &[ast::Cmpop],
|
||||
comparators: &[ast::Expr<Option<Type>>],
|
||||
) -> InferenceResult {
|
||||
let boolean = self.primitives.bool;
|
||||
if ops.len() > 1 && once(left).chain(comparators).any(|expr| expr.custom.unwrap().obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) {
|
||||
return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")]))
|
||||
}
|
||||
|
||||
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
||||
let method = comparison_name(c)
|
||||
.ok_or_else(|| HashSet::from([
|
||||
"unsupported comparator".to_string()
|
||||
]))?
|
||||
.into();
|
||||
|
||||
let ret = typeof_cmpop(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
c,
|
||||
a.custom.unwrap(),
|
||||
b.custom.unwrap(),
|
||||
).map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?;
|
||||
|
||||
self.build_method_call(
|
||||
a.location,
|
||||
method,
|
||||
a.custom.unwrap(),
|
||||
vec![b.custom.unwrap()],
|
||||
Some(boolean),
|
||||
ret,
|
||||
)?;
|
||||
}
|
||||
Ok(boolean)
|
||||
|
||||
let res_lhs = comparators.iter().rev().nth(1).unwrap_or(left);
|
||||
let res_rhs = comparators.iter().rev().nth(0).unwrap();
|
||||
let res_op = ops.iter().rev().nth(0).unwrap();
|
||||
|
||||
Ok(typeof_cmpop(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
res_op,
|
||||
res_lhs.custom.unwrap(),
|
||||
res_rhs.custom.unwrap(),
|
||||
).unwrap().unwrap())
|
||||
}
|
||||
|
||||
/// Infers the type of a subscript expression on an `ndarray`.
|
||||
|
|
|
@ -455,6 +455,174 @@ def test_ndarray_inv():
|
|||
output_ndarray_int32_2(x_int32)
|
||||
output_ndarray_int32_2(y_int32)
|
||||
|
||||
def test_ndarray_eq():
|
||||
x = np_identity(2)
|
||||
y = x == np_full([2, 2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_eq_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x == np_full([2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_eq_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 0.0 == x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_eq_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x == 0.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ne():
|
||||
x = np_identity(2)
|
||||
y = x != np_full([2, 2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ne_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x != np_full([2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ne_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 0.0 != x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ne_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x != 0.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_lt():
|
||||
x = np_identity(2)
|
||||
y = x < np_full([2, 2], 1.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_lt_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x < np_full([2], 1.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_lt_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 1.0 < x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_lt_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x < 1.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_le():
|
||||
x = np_identity(2)
|
||||
y = x <= np_full([2, 2], 0.5)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_le_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x <= np_full([2], 0.5)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_le_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 0.5 <= x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_le_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x <= 0.5
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_gt():
|
||||
x = np_identity(2)
|
||||
y = x > np_full([2, 2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_gt_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x > np_full([2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_gt_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 0.0 > x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_gt_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x > 0.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ge():
|
||||
x = np_identity(2)
|
||||
y = x >= np_full([2, 2], 0.5)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ge_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x >= np_full([2], 0.5)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ge_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 0.5 >= x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ge_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x >= 0.5
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
|
@ -517,5 +685,29 @@ def run() -> int32:
|
|||
test_ndarray_pos()
|
||||
test_ndarray_neg()
|
||||
test_ndarray_inv()
|
||||
test_ndarray_eq()
|
||||
test_ndarray_eq_broadcast()
|
||||
test_ndarray_eq_broadcast_lhs_scalar()
|
||||
test_ndarray_eq_broadcast_rhs_scalar()
|
||||
test_ndarray_ne()
|
||||
test_ndarray_ne_broadcast()
|
||||
test_ndarray_ne_broadcast_lhs_scalar()
|
||||
test_ndarray_ne_broadcast_rhs_scalar()
|
||||
test_ndarray_lt()
|
||||
test_ndarray_lt_broadcast()
|
||||
test_ndarray_lt_broadcast_lhs_scalar()
|
||||
test_ndarray_lt_broadcast_rhs_scalar()
|
||||
test_ndarray_lt()
|
||||
test_ndarray_le_broadcast()
|
||||
test_ndarray_le_broadcast_lhs_scalar()
|
||||
test_ndarray_le_broadcast_rhs_scalar()
|
||||
test_ndarray_gt()
|
||||
test_ndarray_gt_broadcast()
|
||||
test_ndarray_gt_broadcast_lhs_scalar()
|
||||
test_ndarray_gt_broadcast_rhs_scalar()
|
||||
test_ndarray_gt()
|
||||
test_ndarray_ge_broadcast()
|
||||
test_ndarray_ge_broadcast_lhs_scalar()
|
||||
test_ndarray_ge_broadcast_rhs_scalar()
|
||||
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue