forked from M-Labs/nac3
This commit is contained in:
parent
86ed0140cb
commit
819e1e4608
|
@ -151,7 +151,7 @@
|
|||
buildInputs = with pkgs; [
|
||||
# build dependencies
|
||||
packages.x86_64-linux.llvm-nac3
|
||||
llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos
|
||||
llvmPackages_14.clang llvmPackages_14.llvm.out llvmPackages_14.lldb.out # for running nac3standalone demos
|
||||
packages.x86_64-linux.llvm-tools-irrt
|
||||
cargo
|
||||
rustc
|
||||
|
@ -163,7 +163,9 @@
|
|||
clippy
|
||||
pre-commit
|
||||
rustfmt
|
||||
rust-analyzer
|
||||
];
|
||||
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||
};
|
||||
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
||||
name = "nac3-dev-shell-msys2";
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
#pragma once
|
||||
|
||||
#include <irrt/int_defs.hpp>
|
||||
#include <irrt/ndarray/def.hpp>
|
||||
#include <irrt/slice.hpp>
|
||||
|
||||
/*
|
||||
* Notes on `np.transpose(<array>, <axes>)`
|
||||
*
|
||||
* TODO: `axes`, if specified, can actually contain negative indices,
|
||||
* but it is not documented in numpy.
|
||||
*
|
||||
* Supporting it for now.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
namespace ndarray {
|
||||
namespace transpose {
|
||||
namespace util {
|
||||
|
||||
/**
|
||||
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
|
||||
*
|
||||
* Note that `np.transpose`'s `<axe>` argument is optional. If the argument
|
||||
* is specified but the user, use this function to do assertions on it.
|
||||
*
|
||||
* @param ndims The number of dimensions of `<array>`
|
||||
* @param num_axes Number of elements in `<axes>` as specified by the user.
|
||||
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
|
||||
* @param axes The user specified `<axes>`.
|
||||
*/
|
||||
template <typename SizeT>
|
||||
void assert_transpose_axes(ErrorContext* errctx, SizeT ndims, SizeT num_axes,
|
||||
const SizeT* axes) {
|
||||
/*
|
||||
* TODO: `axes` can actually contain negative indices, but it is not documented in numpy.
|
||||
*
|
||||
* Supporting it for now.
|
||||
*/
|
||||
|
||||
if (ndims != num_axes) {
|
||||
errctx->set_error(errctx->error_ids->value_error,
|
||||
"axes don't match array");
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: Optimize this
|
||||
bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
|
||||
for (SizeT i = 0; i < ndims; i++) axe_specified[i] = false;
|
||||
|
||||
for (SizeT i = 0; i < ndims; i++) {
|
||||
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
|
||||
if (axis == slice::OUT_OF_BOUNDS) {
|
||||
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
|
||||
errctx->set_error(
|
||||
errctx->error_ids->value_error,
|
||||
"axis {0} is out of bounds for array of dimension {1}", axis,
|
||||
ndims);
|
||||
return;
|
||||
}
|
||||
|
||||
if (axe_specified[axis]) {
|
||||
errctx->set_error(errctx->error_ids->value_error,
|
||||
"repeated axis in transpose");
|
||||
return;
|
||||
}
|
||||
|
||||
axe_specified[axis] = true;
|
||||
}
|
||||
}
|
||||
} // namespace util
|
||||
|
||||
/**
|
||||
* @brief Create a transpose view of `src_ndarray` and perform proper assertions.
|
||||
*
|
||||
* This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
|
||||
* If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
|
||||
*
|
||||
* The transpose view created is returned by modifying `dst_ndarray`.
|
||||
*
|
||||
* The caller is responsible for setting up `dst_ndarray` before calling this function.
|
||||
* Here is what this function expects from `dst_ndarray` when called:
|
||||
* - `dst_ndarray->data` does not have to be initialized.
|
||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
||||
* - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
|
||||
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
|
||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
||||
* When this function call ends:
|
||||
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
||||
* - `dst_ndarray->ndims` is unchanged
|
||||
* - `dst_ndarray->shape` is updated according to how `np.transpose` works
|
||||
* - `dst_ndarray->strides` is updated according to how `np.transpose` works
|
||||
*
|
||||
* @param src_ndarray The NDArray to build a transpose view on
|
||||
* @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above,
|
||||
* @param num_axes Number of elements in axes, can be undefined if `axes` is nullptr.
|
||||
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is supposed to be `None`.
|
||||
*/
|
||||
template <typename SizeT>
|
||||
void transpose(ErrorContext* errctx, const NDArray<SizeT>* src_ndarray,
|
||||
NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
|
||||
__builtin_assume(src_ndarray->ndims == dst_ndarray->ndims);
|
||||
const auto ndims = src_ndarray->ndims;
|
||||
|
||||
if (axes != nullptr) {
|
||||
util::assert_transpose_axes(errctx, ndims, num_axes, axes);
|
||||
if (errctx->has_error()) return;
|
||||
}
|
||||
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
|
||||
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
|
||||
if (axes == nullptr) {
|
||||
// `np.transpose(<array>, axes=None)`
|
||||
|
||||
/*
|
||||
* Minor note: `np.transpose(<array>, axes=None)` is equivalent to
|
||||
* `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
|
||||
* is reversing the order of strides and shape.
|
||||
*
|
||||
* This is a fast implementation to handle this special (but very common) case.
|
||||
*/
|
||||
|
||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
|
||||
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
|
||||
}
|
||||
} else {
|
||||
// `np.transpose(<array>, <axes>)`
|
||||
|
||||
// Permute strides and shape according to `axes`, while resolving negative indices in `axes`
|
||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||
// `i` cannot be OUT_OF_BOUNDS because of assertions
|
||||
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
|
||||
|
||||
dst_ndarray->shape[axis] = src_ndarray->shape[i];
|
||||
dst_ndarray->strides[axis] = src_ndarray->strides[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace transpose
|
||||
} // namespace ndarray
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::transpose;
|
||||
void __nac3_ndarray_transpose(ErrorContext* errctx,
|
||||
const NDArray<int32_t>* src_ndarray,
|
||||
NDArray<int32_t>* dst_ndarray, int32_t num_axes,
|
||||
const int32_t* axes) {
|
||||
transpose(errctx, src_ndarray, dst_ndarray, num_axes, axes);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_transpose64(ErrorContext* errctx,
|
||||
const NDArray<int64_t>* src_ndarray,
|
||||
NDArray<int64_t>* dst_ndarray, int64_t num_axes,
|
||||
const int64_t* axes) {
|
||||
transpose(errctx, src_ndarray, dst_ndarray, num_axes, axes);
|
||||
}
|
||||
}
|
|
@ -9,5 +9,6 @@
|
|||
#include <irrt/ndarray/fill.hpp>
|
||||
#include <irrt/ndarray/indexing.hpp>
|
||||
#include <irrt/ndarray/reshape.hpp>
|
||||
#include <irrt/ndarray/transpose.hpp>
|
||||
#include <irrt/slice.hpp>
|
||||
#include <irrt/utils.hpp>
|
|
@ -3,3 +3,4 @@ pub mod basic;
|
|||
pub mod fill;
|
||||
pub mod indexing;
|
||||
pub mod reshape;
|
||||
pub mod transpose;
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
use crate::codegen::{
|
||||
irrt::{
|
||||
error_context::{check_error_context, setup_error_context},
|
||||
util::get_sized_dependent_function_name,
|
||||
},
|
||||
model::*,
|
||||
structs::ndarray::NpArray,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
dst_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
axes_or_none: Option<ArraySlice<'ctx, SizeTModel<'ctx>, SizeTModel<'ctx>>>,
|
||||
) -> Pointer<'ctx, StructModel<NpArray<'ctx>>> {
|
||||
let sizet = generator.get_sizet(ctx.ctx);
|
||||
let axes_model = PointerModel(sizet);
|
||||
|
||||
let (num_axes, axes) = match axes_or_none {
|
||||
Some(axes) => (axes.num_elements, axes.pointer),
|
||||
None => {
|
||||
// Please refer to the comment in the IRRT implementation
|
||||
(sizet.constant(ctx.ctx, 0), axes_model.nullptr(ctx.ctx))
|
||||
}
|
||||
};
|
||||
|
||||
let perrctx = setup_error_context(ctx);
|
||||
FunctionBuilder::begin(
|
||||
ctx,
|
||||
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_transpose"),
|
||||
)
|
||||
.arg("errctx", perrctx)
|
||||
.arg("src_ndarray", src_ndarray)
|
||||
.arg("dst_ndarray", dst_ndarray)
|
||||
.arg("num_axes", num_axes)
|
||||
.arg("axes", axes)
|
||||
.returning_void();
|
||||
check_error_context(generator, ctx, perrctx);
|
||||
|
||||
dst_ndarray
|
||||
}
|
|
@ -10,9 +10,10 @@ use crate::{
|
|||
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
||||
},
|
||||
reshape::call_nac3_ndarray_resolve_and_check_new_shape,
|
||||
transpose::call_nac3_ndarray_transpose,
|
||||
},
|
||||
model::*,
|
||||
structs::ndarray::NpArray,
|
||||
structs::{list::List, ndarray::NpArray},
|
||||
util::{array_writer::ArrayWriter, shape::parse_input_shape_arg},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
|
@ -129,3 +130,40 @@ pub fn gen_ndarray_reshape<'ctx>(
|
|||
let reshaped_ndarray = reshape_ndarray_or_copy(generator, context, src_ndarray, &new_shape)?;
|
||||
Ok(reshaped_ndarray.value)
|
||||
}
|
||||
|
||||
pub fn gen_ndarray_transpose<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert!(matches!(args.len(), 1 | 2));
|
||||
|
||||
let sizet = generator.get_sizet(context.ctx);
|
||||
let in_axes_model = PointerModel(StructModel(List { sizet, element: NIntModel(Int32) }));
|
||||
|
||||
// Parse argument #1 ndarray
|
||||
let ndarray_ty = fun.0.args[0].ty;
|
||||
let ndarray_arg = args[0].1.clone().to_basic_value_enum(context, generator, ndarray_ty)?;
|
||||
|
||||
// Parse argument #2 axes (optional)
|
||||
let in_axes = if args.len() == 2 {
|
||||
let in_shape_ty = fun.0.args[1].ty;
|
||||
let in_shape_arg =
|
||||
args[1].1.clone().to_basic_value_enum(context, generator, in_shape_ty)?;
|
||||
|
||||
let in_shape = in_axes_model.review_value(context.ctx, in_shape_arg).unwrap();
|
||||
|
||||
let num_axes = in_shape.gep(context, |f| f.size).load(context, "num_axes");
|
||||
let axes = sizet.array_alloca(context, num_axes, "num_axes");
|
||||
|
||||
Some((in_shape_ty, in_shape_arg))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
// call_nac3_ndarray_transpose(generator, ctx, src_ndarray, dst_ndarray, axes_or_none)
|
||||
|
||||
todo!()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
use inkwell::{types::IntType, values::IntValue};
|
||||
|
||||
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
|
||||
|
||||
/// Convenient structure that looks like
|
||||
/// Python `range`, with a dependent int type.
|
||||
pub struct ForRange<T> {
|
||||
pub start: T,
|
||||
pub stop: T,
|
||||
pub step: T,
|
||||
}
|
||||
|
||||
impl<'ctx, T> ForRange<T>
|
||||
where
|
||||
T: Model<'ctx, Value = IntValue<'ctx>, Type = IntType<'ctx>>,
|
||||
{
|
||||
pub fn end(ctx: &'ctx Context, end: T) -> Self {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn for_range<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
range: ForRange,
|
||||
body: BodyFn,
|
||||
) {
|
||||
}
|
|
@ -1,2 +1,3 @@
|
|||
pub mod array_writer;
|
||||
pub mod control_flow;
|
||||
pub mod shape;
|
||||
|
|
|
@ -496,7 +496,9 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
| PrimDef::FunNpEye
|
||||
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
||||
|
||||
PrimDef::FunNpReshape => self.build_ndarray_view_functions(prim),
|
||||
PrimDef::FunNpReshape | PrimDef::FunNpTranspose => {
|
||||
self.build_ndarray_view_functions(prim)
|
||||
}
|
||||
|
||||
PrimDef::FunStr => self.build_str_function(),
|
||||
|
||||
|
@ -1337,7 +1339,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
|
||||
// Build functions related to NDArray views
|
||||
fn build_ndarray_view_functions(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape]);
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape, PrimDef::FunNpTranspose]);
|
||||
|
||||
match prim {
|
||||
PrimDef::FunNpReshape => {
|
||||
|
@ -1364,6 +1366,56 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
}),
|
||||
)
|
||||
}
|
||||
PrimDef::FunNpTranspose => {
|
||||
/*
|
||||
# NDim has to be known (for checking axes's len)
|
||||
def np_transpose(
|
||||
array: NDArray[DType, NDim],
|
||||
axes: Optional[List[int32]] = None,
|
||||
) -> NDArray[DType, NDim]
|
||||
*/
|
||||
|
||||
// TODO: Allow tuples (or even general iterables in the very far future) on `axes`
|
||||
let optional_axes_ty = self
|
||||
.unifier
|
||||
.subst(
|
||||
self.primitives.option,
|
||||
&VarMap::from([(self.option_tvar.id, self.list_int32)]),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
TopLevelDef::Function {
|
||||
name: prim.name().into(),
|
||||
simple_name: prim.simple_name().into(),
|
||||
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
name: "array".into(),
|
||||
ty: self.primitives.ndarray,
|
||||
default_value: None,
|
||||
},
|
||||
FuncArg {
|
||||
name: "axes".into(),
|
||||
ty: optional_axes_ty,
|
||||
default_value: Some(SymbolValue::OptionNone),
|
||||
},
|
||||
],
|
||||
ret: self.primitives.ndarray,
|
||||
vars: VarMap::default(),
|
||||
})),
|
||||
var_id: vec![self.ndarray_ndims_tvar.id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, obj, fun, args, generator| {
|
||||
numpy_new::view::gen_ndarray_transpose(ctx, &obj, fun, &args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,6 +47,7 @@ pub enum PrimDef {
|
|||
FunNpEye,
|
||||
FunNpIdentity,
|
||||
FunNpReshape,
|
||||
FunNpTranspose,
|
||||
FunRound,
|
||||
FunRound64,
|
||||
FunNpRound,
|
||||
|
@ -206,6 +207,7 @@ impl PrimDef {
|
|||
PrimDef::FunNpEye => fun("np_eye", None),
|
||||
PrimDef::FunNpIdentity => fun("np_identity", None),
|
||||
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
||||
PrimDef::FunRound => fun("round", None),
|
||||
PrimDef::FunRound64 => fun("round64", None),
|
||||
PrimDef::FunNpRound => fun("np_round", None),
|
||||
|
|
|
@ -180,6 +180,7 @@ def patch(module):
|
|||
|
||||
# NumPy view functions
|
||||
module.np_reshape = np.reshape
|
||||
module.np_transpose = np.transpose
|
||||
|
||||
# NumPy Math functions
|
||||
module.np_isnan = np.isnan
|
||||
|
|
Loading…
Reference in New Issue