From f43523ec725922d9889c051d4242ce8b2abb26b0 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 10:37:17 +0800 Subject: [PATCH] [core] codegen: Reimplement ndarray unary op Based on bb992704: core/ndstrides: implement unary op --- nac3core/src/codegen/expr.rs | 18 +++++------ nac3core/src/codegen/numpy.rs | 59 ----------------------------------- 2 files changed, 8 insertions(+), 69 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index e33cebfd..16b54aa4 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1800,10 +1800,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( _ => val.into(), } } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, ty); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - let val = llvm_ndarray_ty.map_value(val.into_pointer_value(), None); + let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty) + .map_value(val.into_pointer_value(), None); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function @@ -1821,20 +1821,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( op }; - let res = numpy::ndarray_elementwise_unaryop_impl( + let mapped_ndarray = ndarray.map( generator, ctx, - ndarray_dtype, - None, - val, - |generator, ctx, val| { - gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))? + NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() }, + |generator, ctx, scalar| { + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))? + .map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype)) .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) }, )?; - res.as_base_value().into() + mapped_ndarray.as_base_value().into() } else { unimplemented!() })) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index fdbb716b..d02103ab 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -195,28 +195,6 @@ where }) } -fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - src: NDArrayValue<'ctx>, - dest: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| { - let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) }; - - map_fn(generator, ctx, elem) - }) -} - /// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of /// the target `ndarray`. fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( @@ -614,43 +592,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) } -pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - operand: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - let res = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &operand, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - }); - - ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| { - map_fn(generator, ctx, elem) - })?; - - Ok(res) -} - /// LLVM-typed implementation for computing elementwise binary operations on two input operands. /// /// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output