core/ndstrides: add ScalarOrNDArray::to_ndarray and NDArrayObject::make_unsized

This commit is contained in:
lyken 2024-08-20 17:01:25 +08:00
parent 1953f42dca
commit 64b47c3144
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
1 changed files with 30 additions and 0 deletions

View File

@ -460,6 +460,22 @@ impl<'ctx> NDArrayObject<'ctx> {
TupleObject::from_objects(generator, ctx, objects)
}
/// Create an unsized ndarray to contain `object`.
pub fn make_unsized<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> NDArrayObject<'ctx> {
// We have to put the value on the stack to get a data pointer.
let data = ctx.builder.build_alloca(object.value.get_type(), "make_unsized").unwrap();
ctx.builder.build_store(data, object.value).unwrap();
let data = Ptr(Int(Byte)).pointer_cast(generator, ctx, data);
let ndarray = NDArrayObject::alloca(generator, ctx, object.ty, 0);
ndarray.instance.set(ctx, |f| f.data, data);
ndarray
}
}
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
@ -500,4 +516,18 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
}
}
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
/// - If this is an ndarray, the ndarray is returned.
/// - If this is a scalar, this function returns new ndarray created with [`NDArrayObject::make_unsized`].
pub fn to_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> NDArrayObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar),
}
}
}