From 18db85fa7be7e70e55b1f31d1c6e2df1a627bc83 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 15:18:07 +0800 Subject: [PATCH] core/ndstrides: implement ndarray.fill() and .copy() --- nac3core/src/codegen/numpy.rs | 57 ++++------------------ nac3core/src/codegen/object/ndarray/mod.rs | 2 + 2 files changed, 11 insertions(+), 48 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index ecc37501..2f46cd1a 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1957,20 +1957,14 @@ pub fn gen_ndarray_copy<'ctx>( assert!(obj.is_some()); assert!(args.is_empty()); - let llvm_usize = generator.get_size_type(context.ctx); - let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; - ndarray_copy_impl( - generator, - context, - this_elem_ty, - NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None), - ) - .map(NDArrayValue::into) + let this = AnyObject { value: this_arg, ty: this_ty }; + let this = NDArrayObject::from_object(generator, context, this); + let ndarray = this.make_copy(generator, context); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.fill`. @@ -1984,48 +1978,15 @@ pub fn gen_ndarray_fill<'ctx>( assert!(obj.is_some()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); - let this_ty = obj.as_ref().unwrap().0; - let this_arg = obj - .as_ref() - .unwrap() - .1 - .clone() - .to_basic_value_enum(context, generator, this_ty)? - .into_pointer_value(); + let this_arg = + obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; let value_ty = fun.0.args[0].ty; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; - ndarray_fill_flattened( - generator, - context, - NDArrayValue::from_ptr_val(this_arg, llvm_usize, None), - |generator, ctx, _| { - let value = if value_arg.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - value_arg.into_pointer_value(), - value_arg.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if value_arg.is_int_value() || value_arg.is_float_value() { - value_arg - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - }, - )?; - + let this = AnyObject { value: this_arg, ty: this_ty }; + let this = NDArrayObject::from_object(generator, context, this); + this.fill(generator, context, value_arg); Ok(()) } diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 85643a60..84c0f2a9 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -417,6 +417,8 @@ impl<'ctx> NDArrayObject<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>, ) { + // TODO: It is possible to optimize this by exploiting contiguous strides with memset. + // Probably best to implement in IRRT. self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { let p = nditer.get_pointer(generator, ctx); ctx.builder.build_store(p, value).unwrap();