forked from M-Labs/nac3
core/ndstrides: implement np_reshape()
This commit is contained in:
parent
48d7032b5e
commit
813dad4ed0
@ -7,6 +7,7 @@
|
|||||||
#include <irrt/ndarray/def.hpp>
|
#include <irrt/ndarray/def.hpp>
|
||||||
#include <irrt/ndarray/indexing.hpp>
|
#include <irrt/ndarray/indexing.hpp>
|
||||||
#include <irrt/ndarray/iter.hpp>
|
#include <irrt/ndarray/iter.hpp>
|
||||||
|
#include <irrt/ndarray/reshape.hpp>
|
||||||
#include <irrt/original.hpp>
|
#include <irrt/original.hpp>
|
||||||
#include <irrt/range.hpp>
|
#include <irrt/range.hpp>
|
||||||
#include <irrt/slice.hpp>
|
#include <irrt/slice.hpp>
|
125
nac3core/irrt/irrt/ndarray/reshape.hpp
Normal file
125
nac3core/irrt/irrt/ndarray/reshape.hpp
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
#include <irrt/ndarray/def.hpp>
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
namespace ndarray
|
||||||
|
{
|
||||||
|
namespace reshape
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
|
||||||
|
*
|
||||||
|
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
|
||||||
|
* modified to contain the resolved dimension.
|
||||||
|
*
|
||||||
|
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
|
||||||
|
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
|
||||||
|
*
|
||||||
|
* @param size The `.size` of `<ndarray>`
|
||||||
|
* @param new_ndims Number of elements in `new_shape`
|
||||||
|
* @param new_shape Target shape to reshape to
|
||||||
|
*/
|
||||||
|
template <typename SizeT> void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT *new_shape)
|
||||||
|
{
|
||||||
|
// Is there a -1 in `new_shape`?
|
||||||
|
bool neg1_exists = false;
|
||||||
|
// Location of -1, only initialized if `neg1_exists` is true
|
||||||
|
SizeT neg1_axis_i;
|
||||||
|
// The computed ndarray size of `new_shape`
|
||||||
|
SizeT new_size = 1;
|
||||||
|
|
||||||
|
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++)
|
||||||
|
{
|
||||||
|
SizeT dim = new_shape[axis_i];
|
||||||
|
if (dim < 0)
|
||||||
|
{
|
||||||
|
if (dim == -1)
|
||||||
|
{
|
||||||
|
if (neg1_exists)
|
||||||
|
{
|
||||||
|
// Multiple `-1` found. Throw an error.
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM,
|
||||||
|
NO_PARAM, NO_PARAM);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
neg1_exists = true;
|
||||||
|
neg1_axis_i = axis_i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// TODO: What? In `np.reshape` any negative dimensions is
|
||||||
|
// treated like its `-1`.
|
||||||
|
//
|
||||||
|
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
|
||||||
|
//
|
||||||
|
// It is not documented by numpy.
|
||||||
|
// Throw an error for now...
|
||||||
|
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i,
|
||||||
|
NO_PARAM);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
new_size *= dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool can_reshape;
|
||||||
|
if (neg1_exists)
|
||||||
|
{
|
||||||
|
// Let `x` be the unknown dimension
|
||||||
|
// Solve `x * <new_size> = <size>`
|
||||||
|
if (new_size == 0 && size == 0)
|
||||||
|
{
|
||||||
|
// `x` has infinitely many solutions
|
||||||
|
can_reshape = false;
|
||||||
|
}
|
||||||
|
else if (new_size == 0 && size != 0)
|
||||||
|
{
|
||||||
|
// `x` has no solutions
|
||||||
|
can_reshape = false;
|
||||||
|
}
|
||||||
|
else if (size % new_size != 0)
|
||||||
|
{
|
||||||
|
// `x` has no integer solutions
|
||||||
|
can_reshape = false;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
can_reshape = true;
|
||||||
|
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
can_reshape = (new_size == size);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!can_reshape)
|
||||||
|
{
|
||||||
|
raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM,
|
||||||
|
NO_PARAM);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace reshape
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
{
|
||||||
|
void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t *new_shape)
|
||||||
|
{
|
||||||
|
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t *new_shape)
|
||||||
|
{
|
||||||
|
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
|
||||||
|
}
|
||||||
|
}
|
@ -1168,3 +1168,22 @@ pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Siz
|
|||||||
);
|
);
|
||||||
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
|
CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
size: Instance<'ctx, Int<SizeT>>,
|
||||||
|
new_ndims: Instance<'ctx, Int<SizeT>>,
|
||||||
|
new_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
) {
|
||||||
|
let name = get_sizet_dependent_function_name(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
"__nac3_ndarray_reshape_resolve_and_check_new_shape",
|
||||||
|
);
|
||||||
|
CallFunction::begin(generator, ctx, &name)
|
||||||
|
.arg(size)
|
||||||
|
.arg(new_ndims)
|
||||||
|
.arg(new_shape)
|
||||||
|
.returning_void();
|
||||||
|
}
|
||||||
|
@ -2096,292 +2096,6 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
|
|
||||||
///
|
|
||||||
/// * `x1` - `NDArray` to reshape.
|
|
||||||
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
|
|
||||||
/// Just like numpy, the `shape` argument can be:
|
|
||||||
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
|
|
||||||
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
|
||||||
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
|
||||||
///
|
|
||||||
/// Note that unlike other generating functions, one of the dimensions in the shape can be negative.
|
|
||||||
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
shape: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "ndarray_reshape";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let (_, shape) = shape;
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
|
||||||
|
|
||||||
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
|
||||||
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
|
||||||
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
|
|
||||||
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
|
|
||||||
|
|
||||||
let out = match shape {
|
|
||||||
BasicValueEnum::PointerValue(shape_list_ptr)
|
|
||||||
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
|
|
||||||
{
|
|
||||||
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
|
|
||||||
|
|
||||||
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
|
|
||||||
// Check for -1 in dimensions
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
(shape_list.load_size(ctx, None), false),
|
|
||||||
|generator, ctx, _, idx| {
|
|
||||||
let ele =
|
|
||||||
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
|
||||||
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
|
|
||||||
|
|
||||||
gen_if_else_expr_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::SLT,
|
|
||||||
ele,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap())
|
|
||||||
},
|
|
||||||
|_, ctx| -> Result<Option<IntValue>, String> {
|
|
||||||
let num_neg_value =
|
|
||||||
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
|
||||||
let num_neg_value = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_add(
|
|
||||||
num_neg_value,
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
|
|
||||||
Ok(None)
|
|
||||||
},
|
|
||||||
|_, ctx| {
|
|
||||||
let acc_value =
|
|
||||||
ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
|
||||||
let acc_value =
|
|
||||||
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
|
|
||||||
ctx.builder.build_store(acc, acc_value).unwrap();
|
|
||||||
Ok(None)
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
|
||||||
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
|
||||||
// Generate the output shape by filling -1 with `rem`
|
|
||||||
create_ndarray_dyn_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&shape_list,
|
|
||||||
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|
|
||||||
|generator, ctx, shape_list, idx| {
|
|
||||||
let dim =
|
|
||||||
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
|
||||||
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
|
||||||
|
|
||||||
Ok(gen_if_else_expr_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::SLT,
|
|
||||||
dim,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap())
|
|
||||||
},
|
|
||||||
|_, _| Ok(Some(rem)),
|
|
||||||
|_, _| Ok(Some(dim)),
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value())
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
BasicValueEnum::StructValue(shape_tuple) => {
|
|
||||||
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
|
||||||
|
|
||||||
let ndims = shape_tuple.get_type().count_fields();
|
|
||||||
// Check for -1 in dims
|
|
||||||
for dim_i in 0..ndims {
|
|
||||||
let dim = ctx
|
|
||||||
.builder
|
|
||||||
.build_extract_value(shape_tuple, dim_i, "")
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value();
|
|
||||||
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
|
||||||
|
|
||||||
gen_if_else_expr_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::SLT,
|
|
||||||
dim,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap())
|
|
||||||
},
|
|
||||||
|_, ctx| -> Result<Option<IntValue>, String> {
|
|
||||||
let num_negs =
|
|
||||||
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
|
||||||
let num_negs = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
|
|
||||||
.unwrap();
|
|
||||||
ctx.builder.build_store(num_neg, num_negs).unwrap();
|
|
||||||
Ok(None)
|
|
||||||
},
|
|
||||||
|_, ctx| {
|
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
|
||||||
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
|
|
||||||
ctx.builder.build_store(acc, acc_val).unwrap();
|
|
||||||
Ok(None)
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
|
||||||
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
|
||||||
let mut shape = Vec::with_capacity(ndims as usize);
|
|
||||||
|
|
||||||
// Reconstruct shape filling negatives with rem
|
|
||||||
for dim_i in 0..ndims {
|
|
||||||
let dim = ctx
|
|
||||||
.builder
|
|
||||||
.build_extract_value(shape_tuple, dim_i, "")
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value();
|
|
||||||
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
|
||||||
|
|
||||||
let dim = gen_if_else_expr_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::SLT,
|
|
||||||
dim,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap())
|
|
||||||
},
|
|
||||||
|_, _| Ok(Some(rem)),
|
|
||||||
|_, _| Ok(Some(dim)),
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value();
|
|
||||||
shape.push(dim);
|
|
||||||
}
|
|
||||||
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
|
||||||
}
|
|
||||||
BasicValueEnum::IntValue(shape_int) => {
|
|
||||||
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
|
||||||
let shape_int = gen_if_else_expr_callback(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::SLT,
|
|
||||||
shape_int,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap())
|
|
||||||
},
|
|
||||||
|_, _| Ok(Some(n_sz)),
|
|
||||||
|_, ctx| {
|
|
||||||
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
|
|
||||||
},
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value();
|
|
||||||
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Only allow one dimension to be negative
|
|
||||||
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
|
|
||||||
.unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"can only specify one unknown dimension",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// The new shape must be compatible with the old shape
|
|
||||||
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"cannot reshape array of size {0} into provided shape of size {1}",
|
|
||||||
[Some(n_sz), Some(out_sz), None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
(n_sz, false),
|
|
||||||
|generator, ctx, _, idx| {
|
|
||||||
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
|
||||||
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(out.as_base_value().into())
|
|
||||||
} else {
|
|
||||||
unreachable!(
|
|
||||||
"{FN_NAME}() not supported for '{}'",
|
|
||||||
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.dot`.
|
/// Generates LLVM IR for `ndarray.dot`.
|
||||||
/// Calculate inner product of two vectors or literals
|
/// Calculate inner product of two vectors or literals
|
||||||
/// For matrix multiplication use `np_matmul`
|
/// For matrix multiplication use `np_matmul`
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
use crate::codegen::{
|
||||||
|
irrt::call_nac3_ndarray_reshape_resolve_and_check_new_shape, model::*, CodeGenContext,
|
||||||
|
CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
use super::{indexing::RustNDIndex, NDArrayObject};
|
use super::{indexing::RustNDIndex, NDArrayObject};
|
||||||
|
|
||||||
@ -26,4 +29,61 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||||||
*self
|
*self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a reshaped view on this ndarray like `np.reshape()`.
|
||||||
|
///
|
||||||
|
/// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be modified as a result.
|
||||||
|
///
|
||||||
|
/// If reshape without copying is impossible, this function will allocate a new ndarray and copy contents.
|
||||||
|
///
|
||||||
|
/// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`].
|
||||||
|
/// * `new_shape` - The target shape to do `np.reshape()`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn reshape_or_copy<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
new_ndims: u64,
|
||||||
|
new_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
) -> Self {
|
||||||
|
// TODO: The current criterion for whether to do a full copy or not is by checking `is_c_contiguous`,
|
||||||
|
// but this is not optimal - there are cases when the ndarray is not contiguous but could be reshaped
|
||||||
|
// without copying data. Look into how numpy does it.
|
||||||
|
|
||||||
|
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||||
|
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
|
||||||
|
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
|
||||||
|
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
||||||
|
|
||||||
|
let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims);
|
||||||
|
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
|
||||||
|
|
||||||
|
// Reolsve negative indices
|
||||||
|
let size = self.size(generator, ctx);
|
||||||
|
let dst_ndims = dst_ndarray.ndims_llvm(generator, ctx.ctx);
|
||||||
|
let dst_shape = dst_ndarray.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
call_nac3_ndarray_reshape_resolve_and_check_new_shape(
|
||||||
|
generator, ctx, size, dst_ndims, dst_shape,
|
||||||
|
);
|
||||||
|
|
||||||
|
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
|
||||||
|
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
|
||||||
|
|
||||||
|
// Inserting into then_bb: reshape is possible without copying
|
||||||
|
ctx.builder.position_at_end(then_bb);
|
||||||
|
dst_ndarray.set_strides_contiguous(generator, ctx);
|
||||||
|
dst_ndarray.instance.set(ctx, |f| f.data, self.instance.get(generator, ctx, |f| f.data));
|
||||||
|
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||||
|
|
||||||
|
// Inserting into else_bb: reshape is impossible without copying
|
||||||
|
ctx.builder.position_at_end(else_bb);
|
||||||
|
dst_ndarray.create_data(generator, ctx);
|
||||||
|
dst_ndarray.copy_data_from(generator, ctx, *self);
|
||||||
|
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||||
|
|
||||||
|
// Reposition for continuation
|
||||||
|
ctx.builder.position_at_end(end_bb);
|
||||||
|
|
||||||
|
dst_ndarray
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
|
|
||||||
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
|
use helper::{debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDefDetails};
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
@ -9,6 +9,7 @@ use inkwell::{
|
|||||||
IntPredicate,
|
IntPredicate,
|
||||||
};
|
};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
|
use numpy::unpack_ndarray_var_tys;
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -17,7 +18,10 @@ use crate::{
|
|||||||
classes::{ProxyValue, RangeValue},
|
classes::{ProxyValue, RangeValue},
|
||||||
model::*,
|
model::*,
|
||||||
numpy::*,
|
numpy::*,
|
||||||
object::{any::AnyObject, ndarray::NDArrayObject},
|
object::{
|
||||||
|
any::AnyObject,
|
||||||
|
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
|
||||||
|
},
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
@ -1467,27 +1471,25 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
||||||
|
|
||||||
|
let in_ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||||
|
&[self.primitives.ndarray],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
match prim {
|
match prim {
|
||||||
PrimDef::FunNpTranspose => {
|
PrimDef::FunNpTranspose => create_fn_by_codegen(
|
||||||
let ndarray_ty = self.unifier.get_fresh_var_with_range(
|
self.unifier,
|
||||||
&[self.ndarray_num_ty],
|
&into_var_map([in_ndarray_ty]),
|
||||||
Some("T".into()),
|
prim.name(),
|
||||||
None,
|
in_ndarray_ty.ty,
|
||||||
);
|
&[(in_ndarray_ty.ty, "x")],
|
||||||
create_fn_by_codegen(
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
self.unifier,
|
let arg_ty = fun.0.args[0].ty;
|
||||||
&into_var_map([ndarray_ty]),
|
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
prim.name(),
|
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
|
||||||
self.ndarray_num_ty,
|
}),
|
||||||
&[(self.ndarray_num_ty, "x")],
|
),
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
|
||||||
let arg_ty = fun.0.args[0].ty;
|
|
||||||
let arg_val =
|
|
||||||
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
|
||||||
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
|
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
|
||||||
// the `param_ty` for `create_fn_by_codegen`.
|
// the `param_ty` for `create_fn_by_codegen`.
|
||||||
@ -1495,20 +1497,42 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
||||||
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
||||||
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
||||||
PrimDef::FunNpReshape => create_fn_by_codegen(
|
PrimDef::FunNpReshape => {
|
||||||
self.unifier,
|
let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding
|
||||||
&VarMap::new(),
|
|
||||||
prim.name(),
|
create_fn_by_codegen(
|
||||||
self.ndarray_num_ty,
|
self.unifier,
|
||||||
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
&VarMap::new(),
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
prim.name(),
|
||||||
let x1_ty = fun.0.args[0].ty;
|
ret_ty,
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
&[
|
||||||
let x2_ty = fun.0.args[1].ty;
|
(in_ndarray_ty.ty, "x"),
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), // Handled by special folding
|
||||||
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
],
|
||||||
}),
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
),
|
let ndarray_ty = fun.0.args[0].ty;
|
||||||
|
let ndarray_val =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||||
|
|
||||||
|
let shape_ty = fun.0.args[1].ty;
|
||||||
|
let shape_val =
|
||||||
|
args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
|
||||||
|
let ndarray = AnyObject { value: ndarray_val, ty: ndarray_ty };
|
||||||
|
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||||
|
|
||||||
|
let shape = AnyObject { value: shape_val, ty: shape_ty };
|
||||||
|
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||||
|
|
||||||
|
// The ndims after reshaping is gotten from the return type of the call.
|
||||||
|
let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||||
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
|
|
||||||
|
let new_ndarray = ndarray.reshape_or_copy(generator, ctx, ndims, shape);
|
||||||
|
Ok(Some(new_ndarray.instance.value.as_basic_value_enum()))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user