forked from M-Labs/nac3
[core] codegen/ndarray: Reimplement np_{copy,fill}
Based on 18db85fa: core/ndstrides: implement ndarray.fill() and .copy()
This commit is contained in:
parent
acb437919d
commit
9ffa2d6552
@ -1315,19 +1315,13 @@ pub fn gen_ndarray_copy<'ctx>(
|
|||||||
assert!(args.is_empty());
|
assert!(args.is_empty());
|
||||||
|
|
||||||
let this_ty = obj.as_ref().unwrap().0;
|
let this_ty = obj.as_ref().unwrap().0;
|
||||||
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
|
|
||||||
let this_arg =
|
let this_arg =
|
||||||
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
||||||
|
|
||||||
let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty);
|
let this = NDArrayType::from_unifier_type(generator, context, this_ty)
|
||||||
|
.map_value(this_arg.into_pointer_value(), None);
|
||||||
ndarray_copy_impl(
|
let ndarray = this.make_copy(generator, context);
|
||||||
generator,
|
Ok(ndarray.as_base_value())
|
||||||
context,
|
|
||||||
this_elem_ty,
|
|
||||||
llvm_this_ty.map_value(this_arg.into_pointer_value(), None),
|
|
||||||
)
|
|
||||||
.map(NDArrayValue::into)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.fill`.
|
/// Generates LLVM IR for `ndarray.fill`.
|
||||||
@ -1342,47 +1336,14 @@ pub fn gen_ndarray_fill<'ctx>(
|
|||||||
assert_eq!(args.len(), 1);
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
let this_ty = obj.as_ref().unwrap().0;
|
let this_ty = obj.as_ref().unwrap().0;
|
||||||
let this_arg = obj
|
let this_arg =
|
||||||
.as_ref()
|
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
||||||
.unwrap()
|
|
||||||
.1
|
|
||||||
.clone()
|
|
||||||
.to_basic_value_enum(context, generator, this_ty)?
|
|
||||||
.into_pointer_value();
|
|
||||||
let value_ty = fun.0.args[0].ty;
|
let value_ty = fun.0.args[0].ty;
|
||||||
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
|
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
|
||||||
|
|
||||||
let llvm_this_ty = NDArrayType::from_unifier_type(generator, context, this_ty);
|
let this = NDArrayType::from_unifier_type(generator, context, this_ty)
|
||||||
|
.map_value(this_arg.into_pointer_value(), None);
|
||||||
ndarray_fill_flattened(
|
this.fill(generator, context, value_arg);
|
||||||
generator,
|
|
||||||
context,
|
|
||||||
llvm_this_ty.map_value(this_arg, None),
|
|
||||||
|generator, ctx, _| {
|
|
||||||
let value = if value_arg.is_pointer_value() {
|
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
|
||||||
|
|
||||||
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
|
|
||||||
|
|
||||||
call_memcpy_generic(
|
|
||||||
ctx,
|
|
||||||
copy,
|
|
||||||
value_arg.into_pointer_value(),
|
|
||||||
value_arg.get_type().size_of().map(Into::into).unwrap(),
|
|
||||||
llvm_i1.const_zero(),
|
|
||||||
);
|
|
||||||
|
|
||||||
copy.into()
|
|
||||||
} else if value_arg.is_int_value() || value_arg.is_float_value() {
|
|
||||||
value_arg
|
|
||||||
} else {
|
|
||||||
codegen_unreachable!(ctx)
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(value)
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -407,6 +407,8 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
value: BasicValueEnum<'ctx>,
|
value: BasicValueEnum<'ctx>,
|
||||||
) {
|
) {
|
||||||
|
// TODO: It is possible to optimize this by exploiting contiguous strides with memset.
|
||||||
|
// Probably best to implement in IRRT.
|
||||||
self.foreach(generator, ctx, |_, ctx, _, nditer| {
|
self.foreach(generator, ctx, |_, ctx, _, nditer| {
|
||||||
let p = nditer.get_pointer(ctx);
|
let p = nditer.get_pointer(ctx);
|
||||||
ctx.builder.build_store(p, value).unwrap();
|
ctx.builder.build_store(p, value).unwrap();
|
||||||
|
Loading…
Reference in New Issue
Block a user