forked from M-Labs/nac3
core/ndstrides: implement cmpop
This commit is contained in:
parent
5f143d2f2f
commit
9bf0e2cbf4
|
@ -1,8 +1,8 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
classes::{
|
classes::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType,
|
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, ProxyType, ProxyValue,
|
||||||
ProxyValue, RangeValue, UntypedArrayLikeAccessor,
|
RangeValue, UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||||
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
|
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
|
||||||
|
@ -11,7 +11,7 @@ use crate::{
|
||||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||||
call_int_umin, call_memcpy_generic,
|
call_int_umin, call_memcpy_generic,
|
||||||
},
|
},
|
||||||
need_sret, numpy,
|
need_sret,
|
||||||
object::ndarray::{NDArrayOut, ScalarOrNDArray},
|
object::ndarray::{NDArrayOut, ScalarOrNDArray},
|
||||||
stmt::{
|
stmt::{
|
||||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||||
|
@ -20,7 +20,7 @@ use crate::{
|
||||||
CodeGenContext, CodeGenTask, CodeGenerator,
|
CodeGenContext, CodeGenTask, CodeGenerator,
|
||||||
},
|
},
|
||||||
symbol_resolver::{SymbolValue, ValueEnum},
|
symbol_resolver::{SymbolValue, ValueEnum},
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||||
|
@ -1842,85 +1842,46 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let (Some(left_ty), left) = left else { unreachable!() };
|
||||||
|
let (Some(right_ty), right) = comparators[0] else { unreachable!() };
|
||||||
let (Some(left_ty), lhs) = left else { unreachable!() };
|
|
||||||
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
|
|
||||||
let op = ops[0];
|
let op = ops[0];
|
||||||
|
|
||||||
let is_ndarray1 =
|
let left = AnyObject { value: left, ty: left_ty };
|
||||||
left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
let left =
|
||||||
let is_ndarray2 =
|
ScalarOrNDArray::split_object(generator, ctx, left).to_ndarray(generator, ctx);
|
||||||
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
|
||||||
|
|
||||||
return if is_ndarray1 && is_ndarray2 {
|
let right = AnyObject { value: right, ty: right_ty };
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
let right =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
ScalarOrNDArray::split_object(generator, ctx, right).to_ndarray(generator, ctx);
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
let result_ndarray = NDArrayObject::broadcast_starmap(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&[left, right],
|
||||||
|
NDArrayOut::NewNDArray { dtype: ctx.primitives.bool },
|
||||||
|
|generator, ctx, scalars| {
|
||||||
|
let left_scalar = scalars[0];
|
||||||
|
let right_scalar = scalars[1];
|
||||||
|
|
||||||
let left_val =
|
let val = gen_cmpop_expr_with_values(
|
||||||
NDArrayValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None);
|
generator,
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
ctx,
|
||||||
generator,
|
(Some(left.dtype), left_scalar),
|
||||||
ctx,
|
&[op],
|
||||||
ctx.primitives.bool,
|
&[(Some(right.dtype), right_scalar)],
|
||||||
None,
|
)?
|
||||||
(left_val.as_base_value().into(), false),
|
.unwrap()
|
||||||
(rhs, false),
|
.to_basic_value_enum(
|
||||||
|generator, ctx, (lhs, rhs)| {
|
ctx,
|
||||||
let val = gen_cmpop_expr_with_values(
|
generator,
|
||||||
generator,
|
ctx.primitives.bool,
|
||||||
ctx,
|
)?;
|
||||||
(Some(ndarray_dtype1), lhs),
|
|
||||||
&[op],
|
|
||||||
&[(Some(ndarray_dtype2), rhs)],
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
return Ok(Some(result_ndarray.instance.value.into()));
|
||||||
} else {
|
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
|
||||||
&mut ctx.unifier,
|
|
||||||
if is_ndarray1 { left_ty } else { right_ty },
|
|
||||||
);
|
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
None,
|
|
||||||
(lhs, !is_ndarray1),
|
|
||||||
(rhs, !is_ndarray2),
|
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|
||||||
let val = gen_cmpop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(Some(ndarray_dtype), lhs),
|
|
||||||
&[op],
|
|
||||||
&[(Some(ndarray_dtype), rhs)],
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue