forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: implement ndarray.fill() and .copy()

This commit is contained in:
lyken 2024-08-20 15:18:07 +08:00
parent 8fe8ccf200
commit ee58cf3fc3
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
2 changed files with 11 additions and 48 deletions

View File

@ -1958,20 +1958,14 @@ pub fn gen_ndarray_copy<'ctx>(
assert!(obj.is_some()); assert!(obj.is_some());
assert!(args.is_empty()); assert!(args.is_empty());
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
let this_arg = let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
ndarray_copy_impl( let this = AnyObject { value: this_arg, ty: this_ty };
generator, let this = NDArrayObject::from_object(generator, context, this);
context, let ndarray = this.make_copy(generator, context);
this_elem_ty, Ok(ndarray.instance.value)
NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
} }
/// Generates LLVM IR for `ndarray.fill`. /// Generates LLVM IR for `ndarray.fill`.
@ -1985,48 +1979,15 @@ pub fn gen_ndarray_fill<'ctx>(
assert!(obj.is_some()); assert!(obj.is_some());
assert_eq!(args.len(), 1); assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj let this_arg =
.as_ref() obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
.unwrap()
.1
.clone()
.to_basic_value_enum(context, generator, this_ty)?
.into_pointer_value();
let value_ty = fun.0.args[0].ty; let value_ty = fun.0.args[0].ty;
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
ndarray_fill_flattened( let this = AnyObject { value: this_arg, ty: this_ty };
generator, let this = NDArrayObject::from_object(generator, context, this);
context, this.fill(generator, context, value_arg);
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None),
|generator, ctx, _| {
let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
value_arg.into_pointer_value(),
value_arg.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg
} else {
codegen_unreachable!(ctx)
};
Ok(value)
},
)?;
Ok(()) Ok(())
} }

View File

@ -418,6 +418,8 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>, value: BasicValueEnum<'ctx>,
) { ) {
// TODO: It is possible to optimize this by exploiting contiguous strides with memset.
// Probably best to implement in IRRT.
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx); let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap(); ctx.builder.build_store(p, value).unwrap();