forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: implement cmpop

This commit is contained in:
lyken 2024-08-21 10:20:20 +08:00
parent 5f143d2f2f
commit 9bf0e2cbf4
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
1 changed files with 37 additions and 76 deletions

View File

@ -1,8 +1,8 @@
use crate::{ use crate::{
codegen::{ codegen::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, ProxyType, ProxyValue,
ProxyValue, RangeValue, UntypedArrayLikeAccessor, RangeValue, UntypedArrayLikeAccessor,
}, },
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name, gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
@ -11,7 +11,7 @@ use crate::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_int_umin, call_memcpy_generic, call_int_umin, call_memcpy_generic,
}, },
need_sret, numpy, need_sret,
object::ndarray::{NDArrayOut, ScalarOrNDArray}, object::ndarray::{NDArrayOut, ScalarOrNDArray},
stmt::{ stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
@ -20,7 +20,7 @@ use crate::{
CodeGenContext, CodeGenTask, CodeGenerator, CodeGenContext, CodeGenTask, CodeGenerator,
}, },
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
typecheck::{ typecheck::{
magic_methods::{Binop, BinopVariant, HasOpInfo}, magic_methods::{Binop, BinopVariant, HasOpInfo},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
@ -1842,39 +1842,33 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let (Some(left_ty), left) = left else { unreachable!() };
let (Some(right_ty), right) = comparators[0] else { unreachable!() };
let (Some(left_ty), lhs) = left else { unreachable!() };
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
let op = ops[0]; let op = ops[0];
let is_ndarray1 = let left = AnyObject { value: left, ty: left_ty };
left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let left =
let is_ndarray2 = ScalarOrNDArray::split_object(generator, ctx, left).to_ndarray(generator, ctx);
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
return if is_ndarray1 && is_ndarray2 { let right = AnyObject { value: right, ty: right_ty };
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); let right =
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); ScalarOrNDArray::split_object(generator, ctx, right).to_ndarray(generator, ctx);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); let result_ndarray = NDArrayObject::broadcast_starmap(
let left_val =
NDArrayValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None);
let res = numpy::ndarray_elementwise_binop_impl(
generator, generator,
ctx, ctx,
ctx.primitives.bool, &[left, right],
None, NDArrayOut::NewNDArray { dtype: ctx.primitives.bool },
(left_val.as_base_value().into(), false), |generator, ctx, scalars| {
(rhs, false), let left_scalar = scalars[0];
|generator, ctx, (lhs, rhs)| { let right_scalar = scalars[1];
let val = gen_cmpop_expr_with_values( let val = gen_cmpop_expr_with_values(
generator, generator,
ctx, ctx,
(Some(ndarray_dtype1), lhs), (Some(left.dtype), left_scalar),
&[op], &[op],
&[(Some(ndarray_dtype2), rhs)], &[(Some(right.dtype), right_scalar)],
)? )?
.unwrap() .unwrap()
.to_basic_value_enum( .to_basic_value_enum(
@ -1887,40 +1881,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
}, },
)?; )?;
Ok(Some(res.as_base_value().into())) return Ok(Some(result_ndarray.instance.value.into()));
} else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier,
if is_ndarray1 { left_ty } else { right_ty },
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ctx.primitives.bool,
None,
(lhs, !is_ndarray1),
(rhs, !is_ndarray2),
|generator, ctx, (lhs, rhs)| {
let val = gen_cmpop_expr_with_values(
generator,
ctx,
(Some(ndarray_dtype), lhs),
&[op],
&[(Some(ndarray_dtype), rhs)],
)?
.unwrap()
.to_basic_value_enum(
ctx,
generator,
ctx.primitives.bool,
)?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
)?;
Ok(Some(res.as_base_value().into()))
};
} }
} }