forked from M-Labs/nac3
1
0
Fork 0
This commit is contained in:
lyken 2024-07-27 12:35:22 +08:00
parent 86ed0140cb
commit 819e1e4608
11 changed files with 335 additions and 4 deletions

View File

@ -151,7 +151,7 @@
buildInputs = with pkgs; [ buildInputs = with pkgs; [
# build dependencies # build dependencies
packages.x86_64-linux.llvm-nac3 packages.x86_64-linux.llvm-nac3
llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos llvmPackages_14.clang llvmPackages_14.llvm.out llvmPackages_14.lldb.out # for running nac3standalone demos
packages.x86_64-linux.llvm-tools-irrt packages.x86_64-linux.llvm-tools-irrt
cargo cargo
rustc rustc
@ -163,7 +163,9 @@
clippy clippy
pre-commit pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -0,0 +1,162 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
/*
* Notes on `np.transpose(<array>, <axes>)`
*
* TODO: `axes`, if specified, can actually contain negative indices,
* but it is not documented in numpy.
*
* Supporting it for now.
*/
namespace {
namespace ndarray {
namespace transpose {
namespace util {
/**
 * @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
 *
 * Note that `np.transpose`'s `<axes>` argument is optional. If the argument
 * is specified by the user, use this function to do assertions on it.
 *
 * On failure, an error is recorded in `errctx` and the function returns
 * immediately; nothing is reported through the return value, so the caller
 * has to inspect `errctx` afterwards.
 *
 * @param errctx Error context that receives a `value_error` when a check fails.
 * @param ndims The number of dimensions of `<array>`
 * @param num_axes Number of elements in `<axes>` as specified by the user.
 * This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is set.
 * @param axes The user specified `<axes>`. Each entry is resolved with
 * `slice::resolve_index_in_length` before being checked.
 */
template <typename SizeT>
void assert_transpose_axes(ErrorContext* errctx, SizeT ndims, SizeT num_axes,
                           const SizeT* axes) {
    /*
     * TODO: `axes` can actually contain negative indices, but it is not documented in numpy.
     *
     * Supporting it for now.
     */

    if (ndims != num_axes) {
        errctx->set_error(errctx->error_ids->value_error,
                          "axes don't match array");
        return;
    }

    // TODO: Optimize this
    // Stack-allocated scratch table: `axe_specified[a]` becomes true once
    // axis `a` has been named by some entry of `axes`.
    bool* axe_specified =
        static_cast<bool*>(__builtin_alloca(sizeof(bool) * ndims));
    for (SizeT i = 0; i < ndims; i++) axe_specified[i] = false;

    for (SizeT i = 0; i < ndims; i++) {
        SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
        if (axis == slice::OUT_OF_BOUNDS) {
            // TODO: numpy actually throws a `numpy.exceptions.AxisError`
            //
            // Report the index the user wrote (`axes[i]`). `axis` only holds
            // the OUT_OF_BOUNDS sentinel here and would be meaningless in
            // the message.
            errctx->set_error(
                errctx->error_ids->value_error,
                "axis {0} is out of bounds for array of dimension {1}",
                axes[i], ndims);
            return;
        }

        if (axe_specified[axis]) {
            errctx->set_error(errctx->error_ids->value_error,
                              "repeated axis in transpose");
            return;
        }

        axe_specified[axis] = true;
    }
}
} // namespace util
/**
 * @brief Build a transposed view of `src_ndarray` in `dst_ndarray`, validating `axes` first.
 *
 * This mirrors `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
 * Pass `nullptr` as `axes` when `<axes>` is meant to be `None`.
 *
 * The result is delivered by writing into `dst_ndarray`; nothing is returned.
 * If validating `axes` records an error in `errctx`, the function returns
 * before touching `dst_ndarray`.
 *
 * What the caller must prepare in `dst_ndarray` before the call:
 * - `dst_ndarray->data` may be uninitialized.
 * - `dst_ndarray->itemsize` may be uninitialized.
 * - `dst_ndarray->ndims` must be initialized and equal to `src_ndarray->ndims`.
 * - `dst_ndarray->shape` must be allocated, though its contents may be uninitialized.
 * - `dst_ndarray->strides` must be allocated, though its contents may be uninitialized.
 *
 * What `dst_ndarray` holds after a successful call:
 * - `dst_ndarray->data` equals `src_ndarray->data` (`dst_ndarray` is only a view of `src_ndarray`)
 * - `dst_ndarray->itemsize` equals `src_ndarray->itemsize`
 * - `dst_ndarray->ndims` is untouched
 * - `dst_ndarray->shape` is the permuted shape of `src_ndarray`
 * - `dst_ndarray->strides` is the permuted strides of `src_ndarray`
 *
 * @param errctx Error context; checked after validating `axes`.
 * @param src_ndarray The NDArray to build a transpose view on
 * @param dst_ndarray The NDArray receiving the view, see the contract above.
 * @param num_axes Number of elements in `axes`; ignored when `axes` is `nullptr`.
 * @param axes Axes permutation, or `nullptr` when `<axes>` is supposed to be `None`.
 */
template <typename SizeT>
void transpose(ErrorContext* errctx, const NDArray<SizeT>* src_ndarray,
               NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
    __builtin_assume(src_ndarray->ndims == dst_ndarray->ndims);
    const auto ndims = src_ndarray->ndims;

    // Only a user-supplied permutation needs to be validated.
    if (axes != nullptr) {
        util::assert_transpose_axes(errctx, ndims, num_axes, axes);
        if (errctx->has_error()) return;
    }

    // The view shares the source's buffer and element size.
    dst_ndarray->data = src_ndarray->data;
    dst_ndarray->itemsize = src_ndarray->itemsize;

    // See https://ajcr.net/stride-guide-part-2/ for how `np.transpose` works behind the scenes.

    if (axes == nullptr) {
        /*
         * `np.transpose(<array>, axes=None)` behaves like
         * `np.transpose(<array>, axes=[N-1, N-2, ..., 0])`: shape and strides
         * are simply reversed. Handle this common case without going through
         * index resolution.
         */
        for (SizeT dst_axis = 0; dst_axis < ndims; dst_axis++) {
            const SizeT src_axis = ndims - dst_axis - 1;
            dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
            dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
        }
        return;
    }

    // `np.transpose(<array>, <axes>)`: permute shape and strides following
    // `axes`, resolving negative entries along the way.
    for (SizeT dst_axis = 0; dst_axis < ndims; dst_axis++) {
        // Cannot be OUT_OF_BOUNDS here: `axes` passed the assertions above.
        const SizeT src_axis =
            slice::resolve_index_in_length(ndims, axes[dst_axis]);
        dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
        dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
    }
}
} // namespace transpose
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::transpose;

/**
 * @brief C-linkage entry point for `transpose` with a 32-bit size type.
 *
 * Forwards every argument unchanged to `transpose<int32_t>`.
 */
void __nac3_ndarray_transpose(ErrorContext* errctx,
                              const NDArray<int32_t>* src_ndarray,
                              NDArray<int32_t>* dst_ndarray, int32_t num_axes,
                              const int32_t* axes) {
    transpose<int32_t>(errctx, src_ndarray, dst_ndarray, num_axes, axes);
}

/**
 * @brief C-linkage entry point for `transpose` with a 64-bit size type.
 *
 * Forwards every argument unchanged to `transpose<int64_t>`.
 */
void __nac3_ndarray_transpose64(ErrorContext* errctx,
                                const NDArray<int64_t>* src_ndarray,
                                NDArray<int64_t>* dst_ndarray,
                                int64_t num_axes, const int64_t* axes) {
    transpose<int64_t>(errctx, src_ndarray, dst_ndarray, num_axes, axes);
}
}

View File

@ -9,5 +9,6 @@
#include <irrt/ndarray/fill.hpp> #include <irrt/ndarray/fill.hpp>
#include <irrt/ndarray/indexing.hpp> #include <irrt/ndarray/indexing.hpp>
#include <irrt/ndarray/reshape.hpp> #include <irrt/ndarray/reshape.hpp>
#include <irrt/ndarray/transpose.hpp>
#include <irrt/slice.hpp> #include <irrt/slice.hpp>
#include <irrt/utils.hpp> #include <irrt/utils.hpp>

View File

@ -3,3 +3,4 @@ pub mod basic;
pub mod fill; pub mod fill;
pub mod indexing; pub mod indexing;
pub mod reshape; pub mod reshape;
pub mod transpose;

View File

@ -0,0 +1,43 @@
use crate::codegen::{
irrt::{
error_context::{check_error_context, setup_error_context},
util::get_sized_dependent_function_name,
},
model::*,
structs::ndarray::NpArray,
CodeGenContext, CodeGenerator,
};
/// Emit a call to the IRRT routine `__nac3_ndarray_transpose` (the actual symbol
/// is chosen by `get_sized_dependent_function_name` from the size type).
///
/// * `src_ndarray` - the ndarray to build the transpose view on.
/// * `dst_ndarray` - the ndarray the IRRT routine writes the view into.
/// * `axes_or_none` - the axes permutation; `None` is forwarded as a null
///   `axes` pointer together with a `num_axes` of 0.
///
/// An error context is set up before the call and checked right after it.
/// Returns `dst_ndarray` unchanged as a handle.
pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
    dst_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
    axes_or_none: Option<ArraySlice<'ctx, SizeTModel<'ctx>, SizeTModel<'ctx>>>,
) -> Pointer<'ctx, StructModel<NpArray<'ctx>>> {
    let sizet = generator.get_sizet(ctx.ctx);

    let (num_axes, axes) = if let Some(user_axes) = axes_or_none {
        (user_axes.num_elements, user_axes.pointer)
    } else {
        // Please refer to the comment in the IRRT implementation
        (sizet.constant(ctx.ctx, 0), PointerModel(sizet).nullptr(ctx.ctx))
    };

    let perrctx = setup_error_context(ctx);

    let irrt_fn_name = get_sized_dependent_function_name(sizet, "__nac3_ndarray_transpose");
    FunctionBuilder::begin(ctx, &irrt_fn_name)
        .arg("errctx", perrctx)
        .arg("src_ndarray", src_ndarray)
        .arg("dst_ndarray", dst_ndarray)
        .arg("num_axes", num_axes)
        .arg("axes", axes)
        .returning_void();

    // Surface any error the IRRT routine recorded.
    check_error_context(generator, ctx, perrctx);

    dst_ndarray
}

View File

@ -10,9 +10,10 @@ use crate::{
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
}, },
reshape::call_nac3_ndarray_resolve_and_check_new_shape, reshape::call_nac3_ndarray_resolve_and_check_new_shape,
transpose::call_nac3_ndarray_transpose,
}, },
model::*, model::*,
structs::ndarray::NpArray, structs::{list::List, ndarray::NpArray},
util::{array_writer::ArrayWriter, shape::parse_input_shape_arg}, util::{array_writer::ArrayWriter, shape::parse_input_shape_arg},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
@ -129,3 +130,40 @@ pub fn gen_ndarray_reshape<'ctx>(
let reshaped_ndarray = reshape_ndarray_or_copy(generator, context, src_ndarray, &new_shape)?; let reshaped_ndarray = reshape_ndarray_or_copy(generator, context, src_ndarray, &new_shape)?;
Ok(reshaped_ndarray.value) Ok(reshaped_ndarray.value)
} }
/// Codegen entry point for `np_transpose(array, axes=None)`.
///
/// NOTE(review): this function is unfinished. It parses its arguments and then
/// unconditionally hits `todo!()`, so it panics whenever it is invoked and
/// never produces the `PointerValue` its signature promises.
///
/// Expects `obj` to be `None` and `args` to hold one (`array`) or two
/// (`array`, `axes`) entries; both are enforced with `assert!`.
pub fn gen_ndarray_transpose<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1 | 2));
let sizet = generator.get_sizet(context.ctx);
// Model used to interpret the `axes` argument: a pointer to a `List` struct of 32-bit ints.
let in_axes_model = PointerModel(StructModel(List { sizet, element: NIntModel(Int32) }));
// Parse argument #1 ndarray
// NOTE(review): `ndarray_arg` is not used anywhere below yet.
let ndarray_ty = fun.0.args[0].ty;
let ndarray_arg = args[0].1.clone().to_basic_value_enum(context, generator, ndarray_ty)?;
// Parse argument #2 axes (optional)
let in_axes = if args.len() == 2 {
let in_shape_ty = fun.0.args[1].ty;
let in_shape_arg =
args[1].1.clone().to_basic_value_enum(context, generator, in_shape_ty)?;
// Panics (via `unwrap`) if the value does not match `in_axes_model`.
let in_shape = in_axes_model.review_value(context.ctx, in_shape_arg).unwrap();
// Number of axes is read from the list's `size` field.
let num_axes = in_shape.gep(context, |f| f.size).load(context, "num_axes");
// NOTE(review): this allocation is labelled "num_axes" although it is bound to `axes`;
// it is also never filled or used below — looks like a copy-paste leftover, confirm.
let axes = sizet.array_alloca(context, num_axes, "num_axes");
// NOTE(review): returns the raw (type, value) pair rather than the `axes` buffer — presumably
// a placeholder until the call below is wired up.
Some((in_shape_ty, in_shape_arg))
} else {
None
};
// Intended final step (not implemented yet):
// call_nac3_ndarray_transpose(generator, ctx, src_ndarray, dst_ndarray, axes_or_none)
todo!()
}

View File

@ -0,0 +1,28 @@
use inkwell::{types::IntType, values::IntValue};
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
/// Convenient structure that looks like
/// Python `range`, with a dependent int type.
///
/// All three fields share the same type `T`. The struct itself places no
/// bound on `T`.
pub struct ForRange<T> {
// Presumably the first value of the range, as in Python `range(start, stop, step)` — TODO confirm.
pub start: T,
// Presumably the bound at which the range ends — TODO confirm whether it is exclusive.
pub stop: T,
// Presumably the increment between consecutive values — TODO confirm.
pub step: T,
}
// Constructors for `ForRange` when `T` is a model whose values are LLVM
// integers (`IntValue`) of an `IntType`.
impl<'ctx, T>
where
T: Model<'ctx, Value = IntValue<'ctx>, Type = IntType<'ctx>>,
{
/// Build a `ForRange` from only an `end` value.
///
/// NOTE(review): not implemented yet — the body is `todo!()`, so calling
/// this panics. `Context` is not brought into scope by the `use` lines
/// visible in this file; confirm it is re-exported through `model::*`.
pub fn end(ctx: &'ctx Context, end: T) -> Self {
todo!()
}
}
pub fn for_range<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
range: ForRange,
body: BodyFn,
) {
}

View File

@ -1,2 +1,3 @@
pub mod array_writer; pub mod array_writer;
pub mod control_flow;
pub mod shape; pub mod shape;

View File

@ -496,7 +496,9 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpEye | PrimDef::FunNpEye
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
PrimDef::FunNpReshape => self.build_ndarray_view_functions(prim), PrimDef::FunNpReshape | PrimDef::FunNpTranspose => {
self.build_ndarray_view_functions(prim)
}
PrimDef::FunStr => self.build_str_function(), PrimDef::FunStr => self.build_str_function(),
@ -1337,7 +1339,7 @@ impl<'a> BuiltinBuilder<'a> {
// Build functions related to NDArray views // Build functions related to NDArray views
fn build_ndarray_view_functions(&mut self, prim: PrimDef) -> TopLevelDef { fn build_ndarray_view_functions(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape]); debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape, PrimDef::FunNpTranspose]);
match prim { match prim {
PrimDef::FunNpReshape => { PrimDef::FunNpReshape => {
@ -1364,6 +1366,56 @@ impl<'a> BuiltinBuilder<'a> {
}), }),
) )
} }
PrimDef::FunNpTranspose => {
/*
# NDim has to be known (for checking axes's len)
def np_transpose(
array: NDArray[DType, NDim],
axes: Optional[List[int32]] = None,
) -> NDArray[DType, NDim]
*/
// TODO: Allow tuples (or even general iterables in the very far future) on `axes`
let optional_axes_ty = self
.unifier
.subst(
self.primitives.option,
&VarMap::from([(self.option_tvar.id, self.list_int32)]),
)
.unwrap();
TopLevelDef::Function {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "array".into(),
ty: self.primitives.ndarray,
default_value: None,
},
FuncArg {
name: "axes".into(),
ty: optional_axes_ty,
default_value: Some(SymbolValue::OptionNone),
},
],
ret: self.primitives.ndarray,
vars: VarMap::default(),
})),
var_id: vec![self.ndarray_ndims_tvar.id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| {
numpy_new::view::gen_ndarray_transpose(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum()))
},
)))),
loc: None,
}
}
_ => unreachable!(), _ => unreachable!(),
} }
} }

View File

@ -47,6 +47,7 @@ pub enum PrimDef {
FunNpEye, FunNpEye,
FunNpIdentity, FunNpIdentity,
FunNpReshape, FunNpReshape,
FunNpTranspose,
FunRound, FunRound,
FunRound64, FunRound64,
FunNpRound, FunNpRound,
@ -206,6 +207,7 @@ impl PrimDef {
PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None), PrimDef::FunNpIdentity => fun("np_identity", None),
PrimDef::FunNpReshape => fun("np_reshape", None), PrimDef::FunNpReshape => fun("np_reshape", None),
PrimDef::FunNpTranspose => fun("np_transpose", None),
PrimDef::FunRound => fun("round", None), PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None), PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpRound => fun("np_round", None),

View File

@ -180,6 +180,7 @@ def patch(module):
# NumPy view functions # NumPy view functions
module.np_reshape = np.reshape module.np_reshape = np.reshape
module.np_transpose = np.transpose
# NumPy Math functions # NumPy Math functions
module.np_isnan = np.isnan module.np_isnan = np.isnan