forked from M-Labs/nac3
[core] codegen/ndarray: Reimplement np_{copy,fill}
Based on 18db85fa: core/ndstrides: implement ndarray.fill() and .copy()
This commit is contained in:
parent
acb437919d
commit
9ffa2d6552
@ -1315,19 +1315,13 @@ pub fn gen_ndarray_copy<'ctx>(
|
||||
assert!(args.is_empty());
|
||||
|
||||
let this_ty = obj.as_ref().unwrap().0;
|
||||
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
|
||||
let this_arg =
|
||||
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
||||
|
||||
let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty);
|
||||
|
||||
ndarray_copy_impl(
|
||||
generator,
|
||||
context,
|
||||
this_elem_ty,
|
||||
llvm_this_ty.map_value(this_arg.into_pointer_value(), None),
|
||||
)
|
||||
.map(NDArrayValue::into)
|
||||
let this = NDArrayType::from_unifier_type(generator, context, this_ty)
|
||||
.map_value(this_arg.into_pointer_value(), None);
|
||||
let ndarray = this.make_copy(generator, context);
|
||||
Ok(ndarray.as_base_value())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.fill`.
|
||||
@ -1342,47 +1336,14 @@ pub fn gen_ndarray_fill<'ctx>(
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
let this_ty = obj.as_ref().unwrap().0;
|
||||
let this_arg = obj
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.1
|
||||
.clone()
|
||||
.to_basic_value_enum(context, generator, this_ty)?
|
||||
.into_pointer_value();
|
||||
let this_arg =
|
||||
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
||||
let value_ty = fun.0.args[0].ty;
|
||||
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
|
||||
|
||||
let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty);
|
||||
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
context,
|
||||
llvm_this_ty.map_value(this_arg, None),
|
||||
|generator, ctx, _| {
|
||||
let value = if value_arg.is_pointer_value() {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
|
||||
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
copy,
|
||||
value_arg.into_pointer_value(),
|
||||
value_arg.get_type().size_of().map(Into::into).unwrap(),
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
copy.into()
|
||||
} else if value_arg.is_int_value() || value_arg.is_float_value() {
|
||||
value_arg
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
},
|
||||
)?;
|
||||
|
||||
let this = NDArrayType::from_unifier_type(generator, context, this_ty)
|
||||
.map_value(this_arg.into_pointer_value(), None);
|
||||
this.fill(generator, context, value_arg);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -407,6 +407,8 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
// TODO: It is possible to optimize this by exploiting contiguous strides with memset.
|
||||
// Probably best to implement in IRRT.
|
||||
self.foreach(generator, ctx, |_, ctx, _, nditer| {
|
||||
let p = nditer.get_pointer(ctx);
|
||||
ctx.builder.build_store(p, value).unwrap();
|
||||
|
Loading…
Reference in New Issue
Block a user