From a920fe05017fc5b87aaabccc34f73d1e53e031ef Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 27 Mar 2024 12:57:11 +0800 Subject: [PATCH] core: Implement elementwise comparison operators --- nac3core/src/codegen/expr.rs | 85 +++++++- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/magic_methods.rs | 40 +++- nac3core/src/typecheck/type_inferencer/mod.rs | 29 ++- nac3standalone/demo/src/ndarray.py | 192 ++++++++++++++++++ 9 files changed, 343 insertions(+), 17 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 2c05d66a2..bec094c96 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1390,12 +1390,92 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>( /// Generates LLVM IR for a comparison operator expression using the [`Type`] and /// [LLVM value][`BasicValueEnum`] of the operands. pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( - _generator: &mut G, + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, left: (Option, BasicValueEnum<'ctx>), ops: &[ast::Cmpop], comparators: &[(Option, BasicValueEnum<'ctx>)], ) -> Result>, String> { + debug_assert_eq!(comparators.len(), ops.len()); + + if comparators.len() == 1 { + let left_ty = ctx.unifier.get_representative(left.0.unwrap()); + let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap()); + + if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (Some(left_ty), lhs) = left else { unreachable!() }; + let (Some(right_ty), rhs) = comparators[0] else { unreachable!() }; + let op = ops[0].clone(); + + let is_ndarray1 = left_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = right_ty.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + return if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); + + assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + let left_val = NDArrayValue::from_ptr_val( + lhs.into_pointer_value(), + llvm_usize, + None + ); + let res = numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + ctx.primitives.bool, + None, + (left_val.as_ptr_value().into(), false), + (rhs, false), + |generator, ctx, (lhs, rhs)| { + let val = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(ndarray_dtype1), lhs), + &[op.clone()], + &[(Some(ndarray_dtype2), rhs)], + )?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?; + + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; + + Ok(Some(res.as_ptr_value().into())) + } else { + let (ndarray_dtype, _) = unpack_ndarray_var_tys( + &mut ctx.unifier, + if is_ndarray1 { left_ty } else { right_ty }, + ); + let res = numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + ctx.primitives.bool, + None, + (lhs, !is_ndarray1), + (rhs, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + let val = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(ndarray_dtype), lhs), + &[op.clone()], + &[(Some(ndarray_dtype), rhs)], + )?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?; + + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; + + Ok(Some(res.as_ptr_value().into())) + } + } + } + let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),) .fold(Ok(None), |prev: Result, String>, (lhs, rhs, op)| { let (left_ty, lhs) = lhs; @@ -1451,7 +1531,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let lhs = lhs.into_float_value(); let rhs = rhs.into_float_value(); - + let op = match op { ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ, ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE, @@ -1465,6 +1545,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } else { unimplemented!() }; + Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) })?; diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index a9dc4ad1c..415a98e1e 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [127]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [156]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index dc36b54b9..41b317822 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar116]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar116\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar145]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar145\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index d6adcee10..f705da720 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [129]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [134]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [158]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [163]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 55767a80c..79520bafe 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar115, typevar116]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar115\", \"typevar116\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar144, typevar145]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar144\", \"typevar145\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index f4f96f2d6..eee42f41a 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [135]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [164]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [143]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [172]\n}\n", ] diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 38d480c3f..d785e92a3 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -483,6 +483,33 @@ pub fn typeof_unaryop( }) } +/// Returns the return type given a comparison operator and its primitive operands. +pub fn typeof_cmpop( + unifier: &mut Unifier, + primitives: &PrimitiveStore, + _op: &Cmpop, + lhs: Type, + rhs: Type, +) -> Result, String> { + let is_left_ndarray = lhs + .obj_id(unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_right_ndarray = rhs + .obj_id(unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + Ok(Some(if is_left_ndarray || is_right_ndarray { + let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?; + let (_, ndims) = unpack_ndarray_var_tys(unifier, brd); + + make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims)) + } else if unifier.unioned(lhs, rhs) { + primitives.bool + } else { + return Ok(None) + })) +} + pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { let PrimitiveStore { int32: int32_t, @@ -508,8 +535,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_mod(unifier, store, t, &[t, ndarray_int_t], None); impl_invert(unifier, store, t, Some(t)); impl_not(unifier, store, t, Some(bool_t)); - impl_comparison(unifier, store, t, &[t], Some(bool_t)); - impl_eq(unifier, store, t, &[t], Some(bool_t)); + impl_comparison(unifier, store, t, &[t, ndarray_int_t], None); + impl_eq(unifier, store, t, &[t, ndarray_int_t], None); } for t in [int32_t, int64_t] { impl_sign(unifier, store, t, Some(t)); @@ -525,12 +552,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_sign(unifier, store, float_t, Some(float_t)); impl_not(unifier, store, float_t, Some(bool_t)); - impl_comparison(unifier, store, float_t, &[float_t], Some(bool_t)); - impl_eq(unifier, store, float_t, &[float_t], Some(bool_t)); + impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None); + impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None); /* bool ======== */ + let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None); impl_not(unifier, store, bool_t, Some(bool_t)); - impl_eq(unifier, store, bool_t, &[bool_t], Some(bool_t)); + impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); /* ndarray ===== */ let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); @@ -544,4 +572,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); + impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); + impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index c19d3d18c..6921b92c2 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1271,22 +1271,45 @@ impl<'a> Inferencer<'a> { ops: &[ast::Cmpop], comparators: &[ast::Expr>], ) -> InferenceResult { - let boolean = self.primitives.bool; + if ops.len() > 1 && once(left).chain(comparators).any(|expr| expr.custom.unwrap().obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) { + return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")])) + } + for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { let method = comparison_name(c) .ok_or_else(|| HashSet::from([ "unsupported comparator".to_string() ]))? .into(); + + let ret = typeof_cmpop( + self.unifier, + self.primitives, + c, + a.custom.unwrap(), + b.custom.unwrap(), + ).map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?; + self.build_method_call( a.location, method, a.custom.unwrap(), vec![b.custom.unwrap()], - Some(boolean), + ret, )?; } - Ok(boolean) + + let res_lhs = comparators.iter().rev().nth(1).unwrap_or(left); + let res_rhs = comparators.iter().rev().nth(0).unwrap(); + let res_op = ops.iter().rev().nth(0).unwrap(); + + Ok(typeof_cmpop( + self.unifier, + self.primitives, + res_op, + res_lhs.custom.unwrap(), + res_rhs.custom.unwrap(), + ).unwrap().unwrap()) } /// Infers the type of a subscript expression on an `ndarray`. diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 454309752..6dcc7ac27 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -455,6 +455,174 @@ def test_ndarray_inv(): output_ndarray_int32_2(x_int32) output_ndarray_int32_2(y_int32) +def test_ndarray_eq(): + x = np_identity(2) + y = x == np_full([2, 2], 0.0) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_eq_broadcast(): + x = np_identity(2) + y = x == np_full([2], 0.0) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_eq_broadcast_lhs_scalar(): + x = np_identity(2) + y = 0.0 == x + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_eq_broadcast_rhs_scalar(): + x = np_identity(2) + y = x == 0.0 + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_ne(): + x = np_identity(2) + y = x != np_full([2, 2], 0.0) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_ne_broadcast(): + x = np_identity(2) + y = x != np_full([2], 0.0) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_ne_broadcast_lhs_scalar(): + x = np_identity(2) + y = 0.0 != x + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_ne_broadcast_rhs_scalar(): + x = np_identity(2) + y = x != 0.0 + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_lt(): + x = np_identity(2) + y = x < np_full([2, 2], 1.0) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_lt_broadcast(): + x = np_identity(2) + y = x < np_full([2], 1.0) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_lt_broadcast_lhs_scalar(): + x = np_identity(2) + y = 1.0 < x + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_lt_broadcast_rhs_scalar(): + x = np_identity(2) + y = x < 1.0 + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_le(): + x = np_identity(2) + y = x <= np_full([2, 2], 0.5) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_le_broadcast(): + x = np_identity(2) + y = x <= np_full([2], 0.5) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_le_broadcast_lhs_scalar(): + x = np_identity(2) + y = 0.5 <= x + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_le_broadcast_rhs_scalar(): + x = np_identity(2) + y = x <= 0.5 + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_gt(): + x = np_identity(2) + y = x > np_full([2, 2], 0.0) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_gt_broadcast(): + x = np_identity(2) + y = x > np_full([2], 0.0) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_gt_broadcast_lhs_scalar(): + x = np_identity(2) + y = 0.0 > x + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_gt_broadcast_rhs_scalar(): + x = np_identity(2) + y = x > 0.0 + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_ge(): + x = np_identity(2) + y = x >= np_full([2, 2], 0.5) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_ge_broadcast(): + x = np_identity(2) + y = x >= np_full([2], 0.5) + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_ge_broadcast_lhs_scalar(): + x = np_identity(2) + y = 0.5 >= x + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + +def test_ndarray_ge_broadcast_rhs_scalar(): + x = np_identity(2) + y = x >= 0.5 + + output_ndarray_float_2(x) + output_ndarray_bool_2(y) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -517,5 +685,29 @@ def run() -> int32: test_ndarray_pos() test_ndarray_neg() test_ndarray_inv() + test_ndarray_eq() + test_ndarray_eq_broadcast() + test_ndarray_eq_broadcast_lhs_scalar() + test_ndarray_eq_broadcast_rhs_scalar() + test_ndarray_ne() + test_ndarray_ne_broadcast() + test_ndarray_ne_broadcast_lhs_scalar() + test_ndarray_ne_broadcast_rhs_scalar() + test_ndarray_lt() + test_ndarray_lt_broadcast() + test_ndarray_lt_broadcast_lhs_scalar() + test_ndarray_lt_broadcast_rhs_scalar() + test_ndarray_lt() + test_ndarray_le_broadcast() + test_ndarray_le_broadcast_lhs_scalar() + test_ndarray_le_broadcast_rhs_scalar() + test_ndarray_gt() + test_ndarray_gt_broadcast() + test_ndarray_gt_broadcast_lhs_scalar() + test_ndarray_gt_broadcast_rhs_scalar() + test_ndarray_gt() + test_ndarray_ge_broadcast() + test_ndarray_ge_broadcast_lhs_scalar() + test_ndarray_ge_broadcast_rhs_scalar() return 0