core: Implement elementwise comparison operators
This commit is contained in:
parent
727a1886b3
commit
a920fe0501
@ -1390,12 +1390,92 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
|
|||||||
/// Generates LLVM IR for a comparison operator expression using the [`Type`] and
|
/// Generates LLVM IR for a comparison operator expression using the [`Type`] and
|
||||||
/// [LLVM value][`BasicValueEnum`] of the operands.
|
/// [LLVM value][`BasicValueEnum`] of the operands.
|
||||||
pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
_generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
left: (Option<Type>, BasicValueEnum<'ctx>),
|
left: (Option<Type>, BasicValueEnum<'ctx>),
|
||||||
ops: &[ast::Cmpop],
|
ops: &[ast::Cmpop],
|
||||||
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
|
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
|
||||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||||
|
debug_assert_eq!(comparators.len(), ops.len());
|
||||||
|
|
||||||
|
if comparators.len() == 1 {
|
||||||
|
let left_ty = ctx.unifier.get_representative(left.0.unwrap());
|
||||||
|
let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap());
|
||||||
|
|
||||||
|
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let (Some(left_ty), lhs) = left else { unreachable!() };
|
||||||
|
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
|
||||||
|
let op = ops[0].clone();
|
||||||
|
|
||||||
|
let is_ndarray1 = left_ty.obj_id(&ctx.unifier)
|
||||||
|
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
let is_ndarray2 = right_ty.obj_id(&ctx.unifier)
|
||||||
|
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
|
||||||
|
return if is_ndarray1 && is_ndarray2 {
|
||||||
|
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
||||||
|
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
||||||
|
|
||||||
|
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
|
let left_val = NDArrayValue::from_ptr_val(
|
||||||
|
lhs.into_pointer_value(),
|
||||||
|
llvm_usize,
|
||||||
|
None
|
||||||
|
);
|
||||||
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ctx.primitives.bool,
|
||||||
|
None,
|
||||||
|
(left_val.as_ptr_value().into(), false),
|
||||||
|
(rhs, false),
|
||||||
|
|generator, ctx, (lhs, rhs)| {
|
||||||
|
let val = gen_cmpop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(Some(ndarray_dtype1), lhs),
|
||||||
|
&[op.clone()],
|
||||||
|
&[(Some(ndarray_dtype2), rhs)],
|
||||||
|
)?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
|
||||||
|
|
||||||
|
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Some(res.as_ptr_value().into()))
|
||||||
|
} else {
|
||||||
|
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
||||||
|
&mut ctx.unifier,
|
||||||
|
if is_ndarray1 { left_ty } else { right_ty },
|
||||||
|
);
|
||||||
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ctx.primitives.bool,
|
||||||
|
None,
|
||||||
|
(lhs, !is_ndarray1),
|
||||||
|
(rhs, !is_ndarray2),
|
||||||
|
|generator, ctx, (lhs, rhs)| {
|
||||||
|
let val = gen_cmpop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(Some(ndarray_dtype), lhs),
|
||||||
|
&[op.clone()],
|
||||||
|
&[(Some(ndarray_dtype), rhs)],
|
||||||
|
)?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
|
||||||
|
|
||||||
|
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Some(res.as_ptr_value().into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
|
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
|
||||||
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
|
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
|
||||||
let (left_ty, lhs) = lhs;
|
let (left_ty, lhs) = lhs;
|
||||||
@ -1465,6 +1545,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
} else {
|
} else {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ expression: res_vec
|
|||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [127]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [156]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar116]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar116\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar145]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar145\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [129]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [158]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [134]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [163]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar115, typevar116]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar115\", \"typevar116\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar144, typevar145]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar144\", \"typevar145\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [135]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [164]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [143]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [172]\n}\n",
|
||||||
]
|
]
|
||||||
|
@ -483,6 +483,33 @@ pub fn typeof_unaryop(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the return type given a comparison operator and its primitive operands.
|
||||||
|
pub fn typeof_cmpop(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
_op: &Cmpop,
|
||||||
|
lhs: Type,
|
||||||
|
rhs: Type,
|
||||||
|
) -> Result<Option<Type>, String> {
|
||||||
|
let is_left_ndarray = lhs
|
||||||
|
.obj_id(unifier)
|
||||||
|
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
let is_right_ndarray = rhs
|
||||||
|
.obj_id(unifier)
|
||||||
|
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
|
||||||
|
Ok(Some(if is_left_ndarray || is_right_ndarray {
|
||||||
|
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
|
||||||
|
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
|
||||||
|
|
||||||
|
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
|
||||||
|
} else if unifier.unioned(lhs, rhs) {
|
||||||
|
primitives.bool
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
||||||
let PrimitiveStore {
|
let PrimitiveStore {
|
||||||
int32: int32_t,
|
int32: int32_t,
|
||||||
@ -508,8 +535,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||||||
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
|
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
|
||||||
impl_invert(unifier, store, t, Some(t));
|
impl_invert(unifier, store, t, Some(t));
|
||||||
impl_not(unifier, store, t, Some(bool_t));
|
impl_not(unifier, store, t, Some(bool_t));
|
||||||
impl_comparison(unifier, store, t, &[t], Some(bool_t));
|
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
|
||||||
impl_eq(unifier, store, t, &[t], Some(bool_t));
|
impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
|
||||||
}
|
}
|
||||||
for t in [int32_t, int64_t] {
|
for t in [int32_t, int64_t] {
|
||||||
impl_sign(unifier, store, t, Some(t));
|
impl_sign(unifier, store, t, Some(t));
|
||||||
@ -525,12 +552,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||||||
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||||
impl_sign(unifier, store, float_t, Some(float_t));
|
impl_sign(unifier, store, float_t, Some(float_t));
|
||||||
impl_not(unifier, store, float_t, Some(bool_t));
|
impl_not(unifier, store, float_t, Some(bool_t));
|
||||||
impl_comparison(unifier, store, float_t, &[float_t], Some(bool_t));
|
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||||
impl_eq(unifier, store, float_t, &[float_t], Some(bool_t));
|
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||||
|
|
||||||
/* bool ======== */
|
/* bool ======== */
|
||||||
|
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
|
||||||
impl_not(unifier, store, bool_t, Some(bool_t));
|
impl_not(unifier, store, bool_t, Some(bool_t));
|
||||||
impl_eq(unifier, store, bool_t, &[bool_t], Some(bool_t));
|
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||||
|
|
||||||
/* ndarray ===== */
|
/* ndarray ===== */
|
||||||
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
||||||
@ -544,4 +572,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||||
|
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
|
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
}
|
}
|
||||||
|
@ -1271,22 +1271,45 @@ impl<'a> Inferencer<'a> {
|
|||||||
ops: &[ast::Cmpop],
|
ops: &[ast::Cmpop],
|
||||||
comparators: &[ast::Expr<Option<Type>>],
|
comparators: &[ast::Expr<Option<Type>>],
|
||||||
) -> InferenceResult {
|
) -> InferenceResult {
|
||||||
let boolean = self.primitives.bool;
|
if ops.len() > 1 && once(left).chain(comparators).any(|expr| expr.custom.unwrap().obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) {
|
||||||
|
return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")]))
|
||||||
|
}
|
||||||
|
|
||||||
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
||||||
let method = comparison_name(c)
|
let method = comparison_name(c)
|
||||||
.ok_or_else(|| HashSet::from([
|
.ok_or_else(|| HashSet::from([
|
||||||
"unsupported comparator".to_string()
|
"unsupported comparator".to_string()
|
||||||
]))?
|
]))?
|
||||||
.into();
|
.into();
|
||||||
|
|
||||||
|
let ret = typeof_cmpop(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
c,
|
||||||
|
a.custom.unwrap(),
|
||||||
|
b.custom.unwrap(),
|
||||||
|
).map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?;
|
||||||
|
|
||||||
self.build_method_call(
|
self.build_method_call(
|
||||||
a.location,
|
a.location,
|
||||||
method,
|
method,
|
||||||
a.custom.unwrap(),
|
a.custom.unwrap(),
|
||||||
vec![b.custom.unwrap()],
|
vec![b.custom.unwrap()],
|
||||||
Some(boolean),
|
ret,
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
Ok(boolean)
|
|
||||||
|
let res_lhs = comparators.iter().rev().nth(1).unwrap_or(left);
|
||||||
|
let res_rhs = comparators.iter().rev().nth(0).unwrap();
|
||||||
|
let res_op = ops.iter().rev().nth(0).unwrap();
|
||||||
|
|
||||||
|
Ok(typeof_cmpop(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
res_op,
|
||||||
|
res_lhs.custom.unwrap(),
|
||||||
|
res_rhs.custom.unwrap(),
|
||||||
|
).unwrap().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infers the type of a subscript expression on an `ndarray`.
|
/// Infers the type of a subscript expression on an `ndarray`.
|
||||||
|
@ -455,6 +455,174 @@ def test_ndarray_inv():
|
|||||||
output_ndarray_int32_2(x_int32)
|
output_ndarray_int32_2(x_int32)
|
||||||
output_ndarray_int32_2(y_int32)
|
output_ndarray_int32_2(y_int32)
|
||||||
|
|
||||||
|
def test_ndarray_eq():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x == np_full([2, 2], 0.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_eq_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x == np_full([2], 0.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_eq_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = 0.0 == x
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_eq_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x == 0.0
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_ne():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x != np_full([2, 2], 0.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_ne_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x != np_full([2], 0.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_ne_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = 0.0 != x
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_ne_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x != 0.0
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_lt():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x < np_full([2, 2], 1.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_lt_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x < np_full([2], 1.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_lt_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = 1.0 < x
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_lt_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x < 1.0
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_le():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x <= np_full([2, 2], 0.5)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_le_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x <= np_full([2], 0.5)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_le_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = 0.5 <= x
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_le_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x <= 0.5
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_gt():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x > np_full([2, 2], 0.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_gt_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x > np_full([2], 0.0)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_gt_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = 0.0 > x
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_gt_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x > 0.0
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_ge():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x >= np_full([2, 2], 0.5)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_ge_broadcast():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x >= np_full([2], 0.5)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_ge_broadcast_lhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = 0.5 >= x
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_ge_broadcast_rhs_scalar():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x >= 0.5
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
test_ndarray_ctor()
|
test_ndarray_ctor()
|
||||||
test_ndarray_empty()
|
test_ndarray_empty()
|
||||||
@ -517,5 +685,29 @@ def run() -> int32:
|
|||||||
test_ndarray_pos()
|
test_ndarray_pos()
|
||||||
test_ndarray_neg()
|
test_ndarray_neg()
|
||||||
test_ndarray_inv()
|
test_ndarray_inv()
|
||||||
|
test_ndarray_eq()
|
||||||
|
test_ndarray_eq_broadcast()
|
||||||
|
test_ndarray_eq_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_eq_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_ne()
|
||||||
|
test_ndarray_ne_broadcast()
|
||||||
|
test_ndarray_ne_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_ne_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_lt()
|
||||||
|
test_ndarray_lt_broadcast()
|
||||||
|
test_ndarray_lt_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_lt_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_lt()
|
||||||
|
test_ndarray_le_broadcast()
|
||||||
|
test_ndarray_le_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_le_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_gt()
|
||||||
|
test_ndarray_gt_broadcast()
|
||||||
|
test_ndarray_gt_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_gt_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_gt()
|
||||||
|
test_ndarray_ge_broadcast()
|
||||||
|
test_ndarray_ge_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_ge_broadcast_rhs_scalar()
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
Loading…
Reference in New Issue
Block a user