1
0
forked from M-Labs/nac3

[core] codegen/ndarray: Reimplement np_{copy,fill}

Based on 18db85fa: core/ndstrides: implement ndarray.fill() and .copy()
This commit is contained in:
David Mak 2024-12-18 09:53:00 +08:00
parent acb437919d
commit 9ffa2d6552
2 changed files with 11 additions and 48 deletions

View File

@ -1315,19 +1315,13 @@ pub fn gen_ndarray_copy<'ctx>(
assert!(args.is_empty()); assert!(args.is_empty());
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
let this_arg = let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty); let this = NDArrayType::from_unifier_type(generator, context, this_ty)
.map_value(this_arg.into_pointer_value(), None);
ndarray_copy_impl( let ndarray = this.make_copy(generator, context);
generator, Ok(ndarray.as_base_value())
context,
this_elem_ty,
llvm_this_ty.map_value(this_arg.into_pointer_value(), None),
)
.map(NDArrayValue::into)
} }
/// Generates LLVM IR for `ndarray.fill`. /// Generates LLVM IR for `ndarray.fill`.
@ -1342,47 +1336,14 @@ pub fn gen_ndarray_fill<'ctx>(
assert_eq!(args.len(), 1); assert_eq!(args.len(), 1);
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj let this_arg =
.as_ref() obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
.unwrap()
.1
.clone()
.to_basic_value_enum(context, generator, this_ty)?
.into_pointer_value();
let value_ty = fun.0.args[0].ty; let value_ty = fun.0.args[0].ty;
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty); let this = NDArrayType::from_unifier_type(generator, context, this_ty)
.map_value(this_arg.into_pointer_value(), None);
ndarray_fill_flattened( this.fill(generator, context, value_arg);
generator,
context,
llvm_this_ty.map_value(this_arg, None),
|generator, ctx, _| {
let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
value_arg.into_pointer_value(),
value_arg.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg
} else {
codegen_unreachable!(ctx)
};
Ok(value)
},
)?;
Ok(()) Ok(())
} }

View File

@ -407,6 +407,8 @@ impl<'ctx> NDArrayValue<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>, value: BasicValueEnum<'ctx>,
) { ) {
// TODO: It is possible to optimize this by exploiting contiguous strides with memset.
// Probably best to implement in IRRT.
self.foreach(generator, ctx, |_, ctx, _, nditer| { self.foreach(generator, ctx, |_, ctx, _, nditer| {
let p = nditer.get_pointer(ctx); let p = nditer.get_pointer(ctx);
ctx.builder.build_store(p, value).unwrap(); ctx.builder.build_store(p, value).unwrap();