1
0
forked from M-Labs/nac3

core: Implement elementwise comparison operators

This commit is contained in:
David Mak 2024-03-27 12:57:11 +08:00
parent 727a1886b3
commit a920fe0501
9 changed files with 343 additions and 17 deletions

View File

@ -1390,12 +1390,92 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
/// Generates LLVM IR for a comparison operator expression using the [`Type`] and /// Generates LLVM IR for a comparison operator expression using the [`Type`] and
/// [LLVM value][`BasicValueEnum`] of the operands. /// [LLVM value][`BasicValueEnum`] of the operands.
pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
_generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
left: (Option<Type>, BasicValueEnum<'ctx>), left: (Option<Type>, BasicValueEnum<'ctx>),
ops: &[ast::Cmpop], ops: &[ast::Cmpop],
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)], comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
debug_assert_eq!(comparators.len(), ops.len());
if comparators.len() == 1 {
let left_ty = ctx.unifier.get_representative(left.0.unwrap());
let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap());
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let (Some(left_ty), lhs) = left else { unreachable!() };
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
let op = ops[0].clone();
let is_ndarray1 = left_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_ndarray2 = right_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
return if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
let left_val = NDArrayValue::from_ptr_val(
lhs.into_pointer_value(),
llvm_usize,
None
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ctx.primitives.bool,
None,
(left_val.as_ptr_value().into(), false),
(rhs, false),
|generator, ctx, (lhs, rhs)| {
let val = gen_cmpop_expr_with_values(
generator,
ctx,
(Some(ndarray_dtype1), lhs),
&[op.clone()],
&[(Some(ndarray_dtype2), rhs)],
)?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
)?;
Ok(Some(res.as_ptr_value().into()))
} else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier,
if is_ndarray1 { left_ty } else { right_ty },
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ctx.primitives.bool,
None,
(lhs, !is_ndarray1),
(rhs, !is_ndarray2),
|generator, ctx, (lhs, rhs)| {
let val = gen_cmpop_expr_with_values(
generator,
ctx,
(Some(ndarray_dtype), lhs),
&[op.clone()],
&[(Some(ndarray_dtype), rhs)],
)?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
)?;
Ok(Some(res.as_ptr_value().into()))
}
}
}
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),) let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| { .fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
let (left_ty, lhs) = lhs; let (left_ty, lhs) = lhs;
@ -1451,7 +1531,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
let lhs = lhs.into_float_value(); let lhs = lhs.into_float_value();
let rhs = rhs.into_float_value(); let rhs = rhs.into_float_value();
let op = match op { let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ, ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE, ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
@ -1465,6 +1545,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
} else { } else {
unimplemented!() unimplemented!()
}; };
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
})?; })?;

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [127]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [156]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar116]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar116\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar145]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar145\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [129]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [158]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [134]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [163]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar115, typevar116]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar115\", \"typevar116\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar144, typevar145]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar144\", \"typevar145\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [135]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [164]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [143]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [172]\n}\n",
] ]

View File

@ -483,6 +483,33 @@ pub fn typeof_unaryop(
}) })
} }
/// Returns the return type given a comparison operator and its primitive operands.
pub fn typeof_cmpop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
_op: &Cmpop,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = rhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
} else if unifier.unioned(lhs, rhs) {
primitives.bool
} else {
return Ok(None)
}))
}
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
let PrimitiveStore { let PrimitiveStore {
int32: int32_t, int32: int32_t,
@ -508,8 +535,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_mod(unifier, store, t, &[t, ndarray_int_t], None); impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
impl_invert(unifier, store, t, Some(t)); impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t, Some(bool_t)); impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, &[t], Some(bool_t)); impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
impl_eq(unifier, store, t, &[t], Some(bool_t)); impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
} }
for t in [int32_t, int64_t] { for t in [int32_t, int64_t] {
impl_sign(unifier, store, t, Some(t)); impl_sign(unifier, store, t, Some(t));
@ -525,12 +552,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_sign(unifier, store, float_t, Some(float_t)); impl_sign(unifier, store, float_t, Some(float_t));
impl_not(unifier, store, float_t, Some(bool_t)); impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, &[float_t], Some(bool_t)); impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_eq(unifier, store, float_t, &[float_t], Some(bool_t)); impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
/* bool ======== */ /* bool ======== */
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
impl_not(unifier, store, bool_t, Some(bool_t)); impl_not(unifier, store, bool_t, Some(bool_t));
impl_eq(unifier, store, bool_t, &[bool_t], Some(bool_t)); impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* ndarray ===== */ /* ndarray ===== */
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
@ -544,4 +572,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
} }

View File

@ -1271,22 +1271,45 @@ impl<'a> Inferencer<'a> {
ops: &[ast::Cmpop], ops: &[ast::Cmpop],
comparators: &[ast::Expr<Option<Type>>], comparators: &[ast::Expr<Option<Type>>],
) -> InferenceResult { ) -> InferenceResult {
let boolean = self.primitives.bool; if ops.len() > 1 && once(left).chain(comparators).any(|expr| expr.custom.unwrap().obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) {
return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")]))
}
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method = comparison_name(c) let method = comparison_name(c)
.ok_or_else(|| HashSet::from([ .ok_or_else(|| HashSet::from([
"unsupported comparator".to_string() "unsupported comparator".to_string()
]))? ]))?
.into(); .into();
let ret = typeof_cmpop(
self.unifier,
self.primitives,
c,
a.custom.unwrap(),
b.custom.unwrap(),
).map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?;
self.build_method_call( self.build_method_call(
a.location, a.location,
method, method,
a.custom.unwrap(), a.custom.unwrap(),
vec![b.custom.unwrap()], vec![b.custom.unwrap()],
Some(boolean), ret,
)?; )?;
} }
Ok(boolean)
let res_lhs = comparators.iter().rev().nth(1).unwrap_or(left);
let res_rhs = comparators.iter().rev().nth(0).unwrap();
let res_op = ops.iter().rev().nth(0).unwrap();
Ok(typeof_cmpop(
self.unifier,
self.primitives,
res_op,
res_lhs.custom.unwrap(),
res_rhs.custom.unwrap(),
).unwrap().unwrap())
} }
/// Infers the type of a subscript expression on an `ndarray`. /// Infers the type of a subscript expression on an `ndarray`.

View File

@ -455,6 +455,174 @@ def test_ndarray_inv():
output_ndarray_int32_2(x_int32) output_ndarray_int32_2(x_int32)
output_ndarray_int32_2(y_int32) output_ndarray_int32_2(y_int32)
def test_ndarray_eq():
x = np_identity(2)
y = x == np_full([2, 2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_eq_broadcast():
x = np_identity(2)
y = x == np_full([2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_eq_broadcast_lhs_scalar():
x = np_identity(2)
y = 0.0 == x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_eq_broadcast_rhs_scalar():
x = np_identity(2)
y = x == 0.0
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ne():
x = np_identity(2)
y = x != np_full([2, 2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ne_broadcast():
x = np_identity(2)
y = x != np_full([2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ne_broadcast_lhs_scalar():
x = np_identity(2)
y = 0.0 != x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ne_broadcast_rhs_scalar():
x = np_identity(2)
y = x != 0.0
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_lt():
x = np_identity(2)
y = x < np_full([2, 2], 1.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_lt_broadcast():
x = np_identity(2)
y = x < np_full([2], 1.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_lt_broadcast_lhs_scalar():
x = np_identity(2)
y = 1.0 < x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_lt_broadcast_rhs_scalar():
x = np_identity(2)
y = x < 1.0
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_le():
x = np_identity(2)
y = x <= np_full([2, 2], 0.5)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_le_broadcast():
x = np_identity(2)
y = x <= np_full([2], 0.5)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_le_broadcast_lhs_scalar():
x = np_identity(2)
y = 0.5 <= x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_le_broadcast_rhs_scalar():
x = np_identity(2)
y = x <= 0.5
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_gt():
x = np_identity(2)
y = x > np_full([2, 2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_gt_broadcast():
x = np_identity(2)
y = x > np_full([2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_gt_broadcast_lhs_scalar():
x = np_identity(2)
y = 0.0 > x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_gt_broadcast_rhs_scalar():
x = np_identity(2)
y = x > 0.0
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ge():
x = np_identity(2)
y = x >= np_full([2, 2], 0.5)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ge_broadcast():
x = np_identity(2)
y = x >= np_full([2], 0.5)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ge_broadcast_lhs_scalar():
x = np_identity(2)
y = 0.5 >= x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ge_broadcast_rhs_scalar():
x = np_identity(2)
y = x >= 0.5
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def run() -> int32: def run() -> int32:
test_ndarray_ctor() test_ndarray_ctor()
test_ndarray_empty() test_ndarray_empty()
@ -517,5 +685,29 @@ def run() -> int32:
test_ndarray_pos() test_ndarray_pos()
test_ndarray_neg() test_ndarray_neg()
test_ndarray_inv() test_ndarray_inv()
test_ndarray_eq()
test_ndarray_eq_broadcast()
test_ndarray_eq_broadcast_lhs_scalar()
test_ndarray_eq_broadcast_rhs_scalar()
test_ndarray_ne()
test_ndarray_ne_broadcast()
test_ndarray_ne_broadcast_lhs_scalar()
test_ndarray_ne_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_lt_broadcast()
test_ndarray_lt_broadcast_lhs_scalar()
test_ndarray_lt_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_le_broadcast()
test_ndarray_le_broadcast_lhs_scalar()
test_ndarray_le_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_gt_broadcast()
test_ndarray_gt_broadcast_lhs_scalar()
test_ndarray_gt_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_ge_broadcast()
test_ndarray_ge_broadcast_lhs_scalar()
test_ndarray_ge_broadcast_rhs_scalar()
return 0 return 0