From 3bd344bd2b2557d2004bc88362180c8110a9b9c1 Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 21 Aug 2024 10:20:20 +0800 Subject: [PATCH] core/ndstrides: implement cmpop --- nac3core/src/codegen/expr.rs | 113 ++++++++++++----------------------- 1 file changed, 37 insertions(+), 76 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 77284088..732993d7 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -28,8 +28,8 @@ use super::object::{ use crate::{ codegen::{ classes::{ - ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, - ProxyValue, RangeValue, UntypedArrayLikeAccessor, + ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, ProxyType, ProxyValue, + RangeValue, UntypedArrayLikeAccessor, }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name, @@ -39,7 +39,7 @@ use crate::{ call_int_umin, call_memcpy_generic, }, macros::codegen_unreachable, - need_sret, numpy, + need_sret, object::ndarray::{NDArrayOut, ScalarOrNDArray}, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, @@ -48,7 +48,7 @@ use crate::{ CodeGenContext, CodeGenTask, CodeGenerator, }, symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, + toplevel::{helper::PrimDef, DefinitionId, TopLevelDef}, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -1848,85 +1848,46 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) }; - let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) }; + let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) }; + let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) }; let op = ops[0]; - let is_ndarray1 = - left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left = AnyObject { value: left, ty: left_ty }; + let left = + ScalarOrNDArray::split_object(generator, ctx, left).to_ndarray(generator, ctx); - return if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); + let right = AnyObject { value: right, ty: right_ty }; + let right = + ScalarOrNDArray::split_object(generator, ctx, right).to_ndarray(generator, ctx); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let result_ndarray = NDArrayObject::broadcast_starmap( + generator, + ctx, + &[left, right], + NDArrayOut::NewNDArray { dtype: ctx.primitives.bool }, + |generator, ctx, scalars| { + let left_scalar = scalars[0]; + let right_scalar = scalars[1]; - let left_val = - NDArrayValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_val.as_base_value().into(), false), - (rhs, false), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype1), lhs), - &[op], - &[(Some(ndarray_dtype2), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; + let val = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left.dtype), left_scalar), + &[op], + &[(Some(right.dtype), right_scalar)], + )? + .unwrap() + .to_basic_value_enum( + ctx, + generator, + ctx.primitives.bool, + )?; - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; - Ok(Some(res.as_base_value().into())) - } else { - let (ndarray_dtype, _) = unpack_ndarray_var_tys( - &mut ctx.unifier, - if is_ndarray1 { left_ty } else { right_ty }, - ); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (lhs, !is_ndarray1), - (rhs, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype), lhs), - &[op], - &[(Some(ndarray_dtype), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; - - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; - - Ok(Some(res.as_base_value().into())) - }; + return Ok(Some(result_ndarray.instance.value.into())); } }