From 9ffa2d6552e5052aba16bdba9b9f081efa138c68 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 18 Dec 2024 09:53:00 +0800 Subject: [PATCH] [core] codegen/ndarray: Reimplement np_{copy,fill} Based on 18db85fa: core/ndstrides: implement ndarray.fill() and .copy() --- nac3core/src/codegen/numpy.rs | 57 ++++------------------ nac3core/src/codegen/values/ndarray/mod.rs | 2 + 2 files changed, 11 insertions(+), 48 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 703d03ec..2f899f95 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1315,19 +1315,13 @@ pub fn gen_ndarray_copy<'ctx>( assert!(args.is_empty()); let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; - let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty); - - ndarray_copy_impl( - generator, - context, - this_elem_ty, - llvm_this_ty.map_value(this_arg.into_pointer_value(), None), - ) - .map(NDArrayValue::into) + let this = NDArrayType::from_unifier_type(generator, context, this_ty) + .map_value(this_arg.into_pointer_value(), None); + let ndarray = this.make_copy(generator, context); + Ok(ndarray.as_base_value()) } /// Generates LLVM IR for `ndarray.fill`. @@ -1342,47 +1336,14 @@ pub fn gen_ndarray_fill<'ctx>( assert_eq!(args.len(), 1); let this_ty = obj.as_ref().unwrap().0; - let this_arg = obj - .as_ref() - .unwrap() - .1 - .clone() - .to_basic_value_enum(context, generator, this_ty)? - .into_pointer_value(); + let this_arg = + obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; let value_ty = fun.0.args[0].ty; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; - let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty); - - ndarray_fill_flattened( - generator, - context, - llvm_this_ty.map_value(this_arg, None), - |generator, ctx, _| { - let value = if value_arg.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - value_arg.into_pointer_value(), - value_arg.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if value_arg.is_int_value() || value_arg.is_float_value() { - value_arg - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - }, - )?; - + let this = NDArrayType::from_unifier_type(generator, context, this_ty) + .map_value(this_arg.into_pointer_value(), None); + this.fill(generator, context, value_arg); Ok(()) } diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index d4a460a5..2907445e 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -407,6 +407,8 @@ impl<'ctx> NDArrayValue<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>, ) { + // TODO: It is possible to optimize this by exploiting contiguous strides with memset. + // Probably best to implement in IRRT. self.foreach(generator, ctx, |_, ctx, _, nditer| { let p = nditer.get_pointer(ctx); ctx.builder.build_store(p, value).unwrap();