forked from M-Labs/nac3
This commit is contained in:
parent
86ed0140cb
commit
819e1e4608
|
@ -151,7 +151,7 @@
|
||||||
buildInputs = with pkgs; [
|
buildInputs = with pkgs; [
|
||||||
# build dependencies
|
# build dependencies
|
||||||
packages.x86_64-linux.llvm-nac3
|
packages.x86_64-linux.llvm-nac3
|
||||||
llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos
|
llvmPackages_14.clang llvmPackages_14.llvm.out llvmPackages_14.lldb.out # for running nac3standalone demos
|
||||||
packages.x86_64-linux.llvm-tools-irrt
|
packages.x86_64-linux.llvm-tools-irrt
|
||||||
cargo
|
cargo
|
||||||
rustc
|
rustc
|
||||||
|
@ -163,7 +163,9 @@
|
||||||
clippy
|
clippy
|
||||||
pre-commit
|
pre-commit
|
||||||
rustfmt
|
rustfmt
|
||||||
|
rust-analyzer
|
||||||
];
|
];
|
||||||
|
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||||
};
|
};
|
||||||
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
||||||
name = "nac3-dev-shell-msys2";
|
name = "nac3-dev-shell-msys2";
|
||||||
|
|
|
@ -0,0 +1,162 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/ndarray/def.hpp>
|
||||||
|
#include <irrt/slice.hpp>
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Notes on `np.transpose(<array>, <axes>)`
|
||||||
|
*
|
||||||
|
* TODO: `axes`, if specified, can actually contain negative indices,
|
||||||
|
* but it is not documented in numpy.
|
||||||
|
*
|
||||||
|
* Supporting it for now.
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray {
|
||||||
|
namespace transpose {
|
||||||
|
namespace util {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
|
||||||
|
*
|
||||||
|
* Note that `np.transpose`'s `<axe>` argument is optional. If the argument
|
||||||
|
* is specified but the user, use this function to do assertions on it.
|
||||||
|
*
|
||||||
|
* @param ndims The number of dimensions of `<array>`
|
||||||
|
* @param num_axes Number of elements in `<axes>` as specified by the user.
|
||||||
|
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
|
||||||
|
* @param axes The user specified `<axes>`.
|
||||||
|
*/
|
||||||
|
template <typename SizeT>
|
||||||
|
void assert_transpose_axes(ErrorContext* errctx, SizeT ndims, SizeT num_axes,
|
||||||
|
const SizeT* axes) {
|
||||||
|
/*
|
||||||
|
* TODO: `axes` can actually contain negative indices, but it is not documented in numpy.
|
||||||
|
*
|
||||||
|
* Supporting it for now.
|
||||||
|
*/
|
||||||
|
|
||||||
|
if (ndims != num_axes) {
|
||||||
|
errctx->set_error(errctx->error_ids->value_error,
|
||||||
|
"axes don't match array");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Optimize this
|
||||||
|
bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
|
||||||
|
for (SizeT i = 0; i < ndims; i++) axe_specified[i] = false;
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < ndims; i++) {
|
||||||
|
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
|
||||||
|
if (axis == slice::OUT_OF_BOUNDS) {
|
||||||
|
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->value_error,
|
||||||
|
"axis {0} is out of bounds for array of dimension {1}", axis,
|
||||||
|
ndims);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (axe_specified[axis]) {
|
||||||
|
errctx->set_error(errctx->error_ids->value_error,
|
||||||
|
"repeated axis in transpose");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
axe_specified[axis] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace util
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Create a transpose view of `src_ndarray` and perform proper assertions.
|
||||||
|
*
|
||||||
|
* This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
|
||||||
|
* If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
|
||||||
|
*
|
||||||
|
* The transpose view created is returned by modifying `dst_ndarray`.
|
||||||
|
*
|
||||||
|
* The caller is responsible for setting up `dst_ndarray` before calling this function.
|
||||||
|
* Here is what this function expects from `dst_ndarray` when called:
|
||||||
|
* - `dst_ndarray->data` does not have to be initialized.
|
||||||
|
* - `dst_ndarray->itemsize` does not have to be initialized.
|
||||||
|
* - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
|
||||||
|
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
|
||||||
|
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
||||||
|
* When this function call ends:
|
||||||
|
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
||||||
|
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
||||||
|
* - `dst_ndarray->ndims` is unchanged
|
||||||
|
* - `dst_ndarray->shape` is updated according to how `np.transpose` works
|
||||||
|
* - `dst_ndarray->strides` is updated according to how `np.transpose` works
|
||||||
|
*
|
||||||
|
* @param src_ndarray The NDArray to build a transpose view on
|
||||||
|
* @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above,
|
||||||
|
* @param num_axes Number of elements in axes, can be undefined if `axes` is nullptr.
|
||||||
|
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is supposed to be `None`.
|
||||||
|
*/
|
||||||
|
template <typename SizeT>
|
||||||
|
void transpose(ErrorContext* errctx, const NDArray<SizeT>* src_ndarray,
|
||||||
|
NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
|
||||||
|
__builtin_assume(src_ndarray->ndims == dst_ndarray->ndims);
|
||||||
|
const auto ndims = src_ndarray->ndims;
|
||||||
|
|
||||||
|
if (axes != nullptr) {
|
||||||
|
util::assert_transpose_axes(errctx, ndims, num_axes, axes);
|
||||||
|
if (errctx->has_error()) return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst_ndarray->data = src_ndarray->data;
|
||||||
|
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||||
|
|
||||||
|
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
|
||||||
|
if (axes == nullptr) {
|
||||||
|
// `np.transpose(<array>, axes=None)`
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Minor note: `np.transpose(<array>, axes=None)` is equivalent to
|
||||||
|
* `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
|
||||||
|
* is reversing the order of strides and shape.
|
||||||
|
*
|
||||||
|
* This is a fast implementation to handle this special (but very common) case.
|
||||||
|
*/
|
||||||
|
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||||
|
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
|
||||||
|
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// `np.transpose(<array>, <axes>)`
|
||||||
|
|
||||||
|
// Permute strides and shape according to `axes`, while resolving negative indices in `axes`
|
||||||
|
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||||
|
// `i` cannot be OUT_OF_BOUNDS because of assertions
|
||||||
|
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
|
||||||
|
|
||||||
|
dst_ndarray->shape[axis] = src_ndarray->shape[i];
|
||||||
|
dst_ndarray->strides[axis] = src_ndarray->strides[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace transpose
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
using namespace ndarray::transpose;
|
||||||
|
void __nac3_ndarray_transpose(ErrorContext* errctx,
|
||||||
|
const NDArray<int32_t>* src_ndarray,
|
||||||
|
NDArray<int32_t>* dst_ndarray, int32_t num_axes,
|
||||||
|
const int32_t* axes) {
|
||||||
|
transpose(errctx, src_ndarray, dst_ndarray, num_axes, axes);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_transpose64(ErrorContext* errctx,
|
||||||
|
const NDArray<int64_t>* src_ndarray,
|
||||||
|
NDArray<int64_t>* dst_ndarray, int64_t num_axes,
|
||||||
|
const int64_t* axes) {
|
||||||
|
transpose(errctx, src_ndarray, dst_ndarray, num_axes, axes);
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,5 +9,6 @@
|
||||||
#include <irrt/ndarray/fill.hpp>
|
#include <irrt/ndarray/fill.hpp>
|
||||||
#include <irrt/ndarray/indexing.hpp>
|
#include <irrt/ndarray/indexing.hpp>
|
||||||
#include <irrt/ndarray/reshape.hpp>
|
#include <irrt/ndarray/reshape.hpp>
|
||||||
|
#include <irrt/ndarray/transpose.hpp>
|
||||||
#include <irrt/slice.hpp>
|
#include <irrt/slice.hpp>
|
||||||
#include <irrt/utils.hpp>
|
#include <irrt/utils.hpp>
|
|
@ -3,3 +3,4 @@ pub mod basic;
|
||||||
pub mod fill;
|
pub mod fill;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
pub mod reshape;
|
pub mod reshape;
|
||||||
|
pub mod transpose;
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
use crate::codegen::{
|
||||||
|
irrt::{
|
||||||
|
error_context::{check_error_context, setup_error_context},
|
||||||
|
util::get_sized_dependent_function_name,
|
||||||
|
},
|
||||||
|
model::*,
|
||||||
|
structs::ndarray::NpArray,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||||
|
dst_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||||
|
axes_or_none: Option<ArraySlice<'ctx, SizeTModel<'ctx>, SizeTModel<'ctx>>>,
|
||||||
|
) -> Pointer<'ctx, StructModel<NpArray<'ctx>>> {
|
||||||
|
let sizet = generator.get_sizet(ctx.ctx);
|
||||||
|
let axes_model = PointerModel(sizet);
|
||||||
|
|
||||||
|
let (num_axes, axes) = match axes_or_none {
|
||||||
|
Some(axes) => (axes.num_elements, axes.pointer),
|
||||||
|
None => {
|
||||||
|
// Please refer to the comment in the IRRT implementation
|
||||||
|
(sizet.constant(ctx.ctx, 0), axes_model.nullptr(ctx.ctx))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let perrctx = setup_error_context(ctx);
|
||||||
|
FunctionBuilder::begin(
|
||||||
|
ctx,
|
||||||
|
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_transpose"),
|
||||||
|
)
|
||||||
|
.arg("errctx", perrctx)
|
||||||
|
.arg("src_ndarray", src_ndarray)
|
||||||
|
.arg("dst_ndarray", dst_ndarray)
|
||||||
|
.arg("num_axes", num_axes)
|
||||||
|
.arg("axes", axes)
|
||||||
|
.returning_void();
|
||||||
|
check_error_context(generator, ctx, perrctx);
|
||||||
|
|
||||||
|
dst_ndarray
|
||||||
|
}
|
|
@ -10,9 +10,10 @@ use crate::{
|
||||||
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
||||||
},
|
},
|
||||||
reshape::call_nac3_ndarray_resolve_and_check_new_shape,
|
reshape::call_nac3_ndarray_resolve_and_check_new_shape,
|
||||||
|
transpose::call_nac3_ndarray_transpose,
|
||||||
},
|
},
|
||||||
model::*,
|
model::*,
|
||||||
structs::ndarray::NpArray,
|
structs::{list::List, ndarray::NpArray},
|
||||||
util::{array_writer::ArrayWriter, shape::parse_input_shape_arg},
|
util::{array_writer::ArrayWriter, shape::parse_input_shape_arg},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
|
@ -129,3 +130,40 @@ pub fn gen_ndarray_reshape<'ctx>(
|
||||||
let reshaped_ndarray = reshape_ndarray_or_copy(generator, context, src_ndarray, &new_shape)?;
|
let reshaped_ndarray = reshape_ndarray_or_copy(generator, context, src_ndarray, &new_shape)?;
|
||||||
Ok(reshaped_ndarray.value)
|
Ok(reshaped_ndarray.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn gen_ndarray_transpose<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert!(matches!(args.len(), 1 | 2));
|
||||||
|
|
||||||
|
let sizet = generator.get_sizet(context.ctx);
|
||||||
|
let in_axes_model = PointerModel(StructModel(List { sizet, element: NIntModel(Int32) }));
|
||||||
|
|
||||||
|
// Parse argument #1 ndarray
|
||||||
|
let ndarray_ty = fun.0.args[0].ty;
|
||||||
|
let ndarray_arg = args[0].1.clone().to_basic_value_enum(context, generator, ndarray_ty)?;
|
||||||
|
|
||||||
|
// Parse argument #2 axes (optional)
|
||||||
|
let in_axes = if args.len() == 2 {
|
||||||
|
let in_shape_ty = fun.0.args[1].ty;
|
||||||
|
let in_shape_arg =
|
||||||
|
args[1].1.clone().to_basic_value_enum(context, generator, in_shape_ty)?;
|
||||||
|
|
||||||
|
let in_shape = in_axes_model.review_value(context.ctx, in_shape_arg).unwrap();
|
||||||
|
|
||||||
|
let num_axes = in_shape.gep(context, |f| f.size).load(context, "num_axes");
|
||||||
|
let axes = sizet.array_alloca(context, num_axes, "num_axes");
|
||||||
|
|
||||||
|
Some((in_shape_ty, in_shape_arg))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
// call_nac3_ndarray_transpose(generator, ctx, src_ndarray, dst_ndarray, axes_or_none)
|
||||||
|
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
use inkwell::{types::IntType, values::IntValue};
|
||||||
|
|
||||||
|
use crate::codegen::{model::*, CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
|
/// Convenient structure that looks like
|
||||||
|
/// Python `range`, with a dependent int type.
|
||||||
|
pub struct ForRange<T> {
|
||||||
|
pub start: T,
|
||||||
|
pub stop: T,
|
||||||
|
pub step: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, T> ForRange<T>
|
||||||
|
where
|
||||||
|
T: Model<'ctx, Value = IntValue<'ctx>, Type = IntType<'ctx>>,
|
||||||
|
{
|
||||||
|
pub fn end(ctx: &'ctx Context, end: T) -> Self {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn for_range<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
range: ForRange,
|
||||||
|
body: BodyFn,
|
||||||
|
) {
|
||||||
|
}
|
|
@ -1,2 +1,3 @@
|
||||||
pub mod array_writer;
|
pub mod array_writer;
|
||||||
|
pub mod control_flow;
|
||||||
pub mod shape;
|
pub mod shape;
|
||||||
|
|
|
@ -496,7 +496,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
| PrimDef::FunNpEye
|
| PrimDef::FunNpEye
|
||||||
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpReshape => self.build_ndarray_view_functions(prim),
|
PrimDef::FunNpReshape | PrimDef::FunNpTranspose => {
|
||||||
|
self.build_ndarray_view_functions(prim)
|
||||||
|
}
|
||||||
|
|
||||||
PrimDef::FunStr => self.build_str_function(),
|
PrimDef::FunStr => self.build_str_function(),
|
||||||
|
|
||||||
|
@ -1337,7 +1339,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
|
|
||||||
// Build functions related to NDArray views
|
// Build functions related to NDArray views
|
||||||
fn build_ndarray_view_functions(&mut self, prim: PrimDef) -> TopLevelDef {
|
fn build_ndarray_view_functions(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape]);
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape, PrimDef::FunNpTranspose]);
|
||||||
|
|
||||||
match prim {
|
match prim {
|
||||||
PrimDef::FunNpReshape => {
|
PrimDef::FunNpReshape => {
|
||||||
|
@ -1364,6 +1366,56 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
PrimDef::FunNpTranspose => {
|
||||||
|
/*
|
||||||
|
# NDim has to be known (for checking axes's len)
|
||||||
|
def np_transpose(
|
||||||
|
array: NDArray[DType, NDim],
|
||||||
|
axes: Optional[List[int32]] = None,
|
||||||
|
) -> NDArray[DType, NDim]
|
||||||
|
*/
|
||||||
|
|
||||||
|
// TODO: Allow tuples (or even general iterables in the very far future) on `axes`
|
||||||
|
let optional_axes_ty = self
|
||||||
|
.unifier
|
||||||
|
.subst(
|
||||||
|
self.primitives.option,
|
||||||
|
&VarMap::from([(self.option_tvar.id, self.list_int32)]),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
TopLevelDef::Function {
|
||||||
|
name: prim.name().into(),
|
||||||
|
simple_name: prim.simple_name().into(),
|
||||||
|
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg {
|
||||||
|
name: "array".into(),
|
||||||
|
ty: self.primitives.ndarray,
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
FuncArg {
|
||||||
|
name: "axes".into(),
|
||||||
|
ty: optional_axes_ty,
|
||||||
|
default_value: Some(SymbolValue::OptionNone),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret: self.primitives.ndarray,
|
||||||
|
vars: VarMap::default(),
|
||||||
|
})),
|
||||||
|
var_id: vec![self.ndarray_ndims_tvar.id],
|
||||||
|
instance_to_symbol: HashMap::default(),
|
||||||
|
instance_to_stmt: HashMap::default(),
|
||||||
|
resolver: None,
|
||||||
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|
|ctx, obj, fun, args, generator| {
|
||||||
|
numpy_new::view::gen_ndarray_transpose(ctx, &obj, fun, &args, generator)
|
||||||
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
},
|
||||||
|
)))),
|
||||||
|
loc: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,6 +47,7 @@ pub enum PrimDef {
|
||||||
FunNpEye,
|
FunNpEye,
|
||||||
FunNpIdentity,
|
FunNpIdentity,
|
||||||
FunNpReshape,
|
FunNpReshape,
|
||||||
|
FunNpTranspose,
|
||||||
FunRound,
|
FunRound,
|
||||||
FunRound64,
|
FunRound64,
|
||||||
FunNpRound,
|
FunNpRound,
|
||||||
|
@ -206,6 +207,7 @@ impl PrimDef {
|
||||||
PrimDef::FunNpEye => fun("np_eye", None),
|
PrimDef::FunNpEye => fun("np_eye", None),
|
||||||
PrimDef::FunNpIdentity => fun("np_identity", None),
|
PrimDef::FunNpIdentity => fun("np_identity", None),
|
||||||
PrimDef::FunNpReshape => fun("np_reshape", None),
|
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||||
|
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
||||||
PrimDef::FunRound => fun("round", None),
|
PrimDef::FunRound => fun("round", None),
|
||||||
PrimDef::FunRound64 => fun("round64", None),
|
PrimDef::FunRound64 => fun("round64", None),
|
||||||
PrimDef::FunNpRound => fun("np_round", None),
|
PrimDef::FunNpRound => fun("np_round", None),
|
||||||
|
|
|
@ -180,6 +180,7 @@ def patch(module):
|
||||||
|
|
||||||
# NumPy view functions
|
# NumPy view functions
|
||||||
module.np_reshape = np.reshape
|
module.np_reshape = np.reshape
|
||||||
|
module.np_transpose = np.transpose
|
||||||
|
|
||||||
# NumPy Math functions
|
# NumPy Math functions
|
||||||
module.np_isnan = np.isnan
|
module.np_isnan = np.isnan
|
||||||
|
|
Loading…
Reference in New Issue