forked from M-Labs/nac3
core/ndstrides: implement general ndarray reshaping
This commit is contained in:
parent
7372ef0504
commit
5c6537565c
|
@ -0,0 +1,116 @@
|
|||
#pragma once
|
||||
|
||||
#include <irrt/error_context.hpp>
|
||||
#include <irrt/int_defs.hpp>
|
||||
#include <irrt/ndarray/def.hpp>
|
||||
|
||||
namespace {
|
||||
namespace ndarray {
|
||||
namespace reshape {
|
||||
namespace util {
|
||||
|
||||
/**
|
||||
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
|
||||
*
|
||||
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
|
||||
* modified to contain the resolved dimension.
|
||||
*
|
||||
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
|
||||
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
|
||||
*
|
||||
* @param size The `.size` of `<ndarray>`
|
||||
* @param new_ndims Number of elements in `new_shape`
|
||||
* @param new_shape Target shape to reshape to
|
||||
*/
|
||||
template <typename SizeT>
|
||||
void resolve_and_check_new_shape(ErrorContext* errctx, SizeT size,
|
||||
SizeT new_ndims, SizeT* new_shape) {
|
||||
// Is there a -1 in `new_shape`?
|
||||
bool neg1_exists = false;
|
||||
// Location of -1, only initialized if `neg1_exists` is true
|
||||
SizeT neg1_axis_i;
|
||||
// The computed ndarray size of `new_shape`
|
||||
SizeT new_size = 1;
|
||||
|
||||
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
|
||||
SizeT dim = new_shape[axis_i];
|
||||
if (dim < 0) {
|
||||
if (dim == -1) {
|
||||
if (neg1_exists) {
|
||||
// Multiple `-1` found. Throw an error.
|
||||
errctx->set_error(errctx->error_ids->value_error,
|
||||
"can only specify one unknown dimension");
|
||||
return;
|
||||
} else {
|
||||
neg1_exists = true;
|
||||
neg1_axis_i = axis_i;
|
||||
}
|
||||
} else {
|
||||
// TODO: What? In `np.reshape` any negative dimensions is
|
||||
// treated like its `-1`.
|
||||
//
|
||||
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
|
||||
//
|
||||
// It is not documented by numpy.
|
||||
// Throw an error for now...
|
||||
|
||||
errctx->set_error(errctx->error_ids->value_error,
|
||||
"Found negative dimension {0} on axis {1}",
|
||||
dim, axis_i);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
new_size *= dim;
|
||||
}
|
||||
}
|
||||
|
||||
bool can_reshape;
|
||||
if (neg1_exists) {
|
||||
// Let `x` be the unknown dimension
|
||||
// solve `x * <new_size> = <size>`
|
||||
if (new_size == 0 && size == 0) {
|
||||
// `x` has infinitely many solutions
|
||||
can_reshape = false;
|
||||
} else if (new_size == 0 && size != 0) {
|
||||
// `x` has no solutions
|
||||
can_reshape = false;
|
||||
} else if (size % new_size != 0) {
|
||||
// `x` has no integer solutions
|
||||
can_reshape = false;
|
||||
} else {
|
||||
can_reshape = true;
|
||||
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
|
||||
}
|
||||
} else {
|
||||
can_reshape = (new_size == size);
|
||||
}
|
||||
|
||||
if (!can_reshape) {
|
||||
errctx->set_error(errctx->error_ids->value_error,
|
||||
"cannot reshape array of size {0} into given shape",
|
||||
size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
} // namespace util
|
||||
} // namespace reshape
|
||||
} // namespace ndarray
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
|
||||
void __nac3_ndarray_resolve_and_check_new_shape(ErrorContext* errctx,
|
||||
int32_t size, int32_t new_ndims,
|
||||
int32_t* new_shape) {
|
||||
ndarray::reshape::util::resolve_and_check_new_shape(errctx, size, new_ndims,
|
||||
new_shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_resolve_and_check_new_shape64(ErrorContext* errctx,
|
||||
int64_t size,
|
||||
int64_t new_ndims,
|
||||
int64_t* new_shape) {
|
||||
ndarray::reshape::util::resolve_and_check_new_shape(errctx, size, new_ndims,
|
||||
new_shape);
|
||||
}
|
||||
}
|
|
@ -8,5 +8,6 @@
|
|||
#include <irrt/ndarray/def.hpp>
|
||||
#include <irrt/ndarray/fill.hpp>
|
||||
#include <irrt/ndarray/indexing.hpp>
|
||||
#include <irrt/ndarray/reshape.hpp>
|
||||
#include <irrt/slice.hpp>
|
||||
#include <irrt/utils.hpp>
|
|
@ -2,3 +2,4 @@ pub mod allocation;
|
|||
pub mod basic;
|
||||
pub mod fill;
|
||||
pub mod indexing;
|
||||
pub mod reshape;
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
use crate::codegen::{
|
||||
irrt::{
|
||||
error_context::{check_error_context, setup_error_context},
|
||||
util::get_sized_dependent_function_name,
|
||||
},
|
||||
model::*,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
pub fn call_nac3_ndarray_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
size: SizeT<'ctx>,
|
||||
new_ndims: SizeT<'ctx>,
|
||||
new_shape: Pointer<'ctx, SizeTModel<'ctx>>,
|
||||
) {
|
||||
let sizet = generator.get_sizet(ctx.ctx);
|
||||
|
||||
let perrctx = setup_error_context(ctx);
|
||||
FunctionBuilder::begin(
|
||||
ctx,
|
||||
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_resolve_and_check_new_shape"),
|
||||
)
|
||||
.arg("errctx", perrctx)
|
||||
.arg("size", size)
|
||||
.arg("new_ndims", new_ndims)
|
||||
.arg("new_shape", new_shape)
|
||||
.returning_void();
|
||||
check_error_context(generator, ctx, perrctx);
|
||||
}
|
|
@ -1 +1,2 @@
|
|||
pub mod factory;
|
||||
pub mod view;
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
use inkwell::values::PointerValue;
|
||||
use nac3parser::ast::StrRef;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt::ndarray::{
|
||||
allocation::{alloca_ndarray, init_ndarray_shape},
|
||||
basic::{
|
||||
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_nbytes,
|
||||
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
||||
},
|
||||
reshape::call_nac3_ndarray_resolve_and_check_new_shape,
|
||||
},
|
||||
model::*,
|
||||
structs::ndarray::NpArray,
|
||||
util::{array_writer::ArrayWriter, shape::parse_input_shape_arg},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::DefinitionId,
|
||||
typecheck::typedef::{FunSignature, Type},
|
||||
};
|
||||
|
||||
fn reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
new_shape: &ArrayWriter<'ctx, G, SizeTModel<'ctx>, SizeTModel<'ctx>>,
|
||||
) -> Result<Pointer<'ctx, StructModel<NpArray<'ctx>>>, String> {
|
||||
let byte_model = NIntModel(Byte);
|
||||
|
||||
/*
|
||||
Reference pseudo-code:
|
||||
```c
|
||||
NDArray<SizeT>* src_ndarray;
|
||||
|
||||
NDArray<SizeT>* dst_ndarray = __builtin_alloca(...);
|
||||
dst_ndarray->ndims = ...
|
||||
dst_ndarray->strides = __builtin_alloca(...);
|
||||
dst_ndarray->shape = ... // Directly set by user, may contain -1, or even illegal values.
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
set_strides_by_shape(dst_ndarray);
|
||||
|
||||
// Do assertions on `dst_ndarray->shape` and resolve -1
|
||||
|
||||
resolve_and_check_new_shape(ndarray_size(src_ndarray), dst_ndarray->shape);
|
||||
|
||||
if (is_c_contiguous(src_ndarray)) {
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
} else {
|
||||
dst_ndarray->data = __builtin_alloca( ndarray_nbytes(dst_ndarray) );
|
||||
copy_data(src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
return dst_ndarray;
|
||||
```
|
||||
*/
|
||||
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then");
|
||||
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
|
||||
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
||||
|
||||
// current_bb
|
||||
let dst_ndarray = alloca_ndarray(generator, ctx, new_shape.count, "ndarray").unwrap();
|
||||
|
||||
init_ndarray_shape(generator, ctx, dst_ndarray, new_shape)?;
|
||||
dst_ndarray
|
||||
.gep(ctx, |f| f.itemsize)
|
||||
.store(ctx, src_ndarray.gep(ctx, |f| f.itemsize).load(ctx, "itemsize"));
|
||||
|
||||
call_nac3_ndarray_set_strides_by_shape(generator, ctx, dst_ndarray);
|
||||
|
||||
let src_ndarray_size = call_nac3_ndarray_size(generator, ctx, src_ndarray);
|
||||
call_nac3_ndarray_resolve_and_check_new_shape(
|
||||
generator,
|
||||
ctx,
|
||||
src_ndarray_size,
|
||||
dst_ndarray.gep(ctx, |f| f.ndims).load(ctx, "ndims"),
|
||||
dst_ndarray.gep(ctx, |f| f.shape).load(ctx, "shape"),
|
||||
);
|
||||
|
||||
let is_c_contiguous = call_nac3_ndarray_is_c_contiguous(generator, ctx, src_ndarray);
|
||||
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
|
||||
|
||||
// then_bb: reshape is possible without copying
|
||||
ctx.builder.position_at_end(then_bb);
|
||||
dst_ndarray.gep(ctx, |f| f.data).store(ctx, src_ndarray.gep(ctx, |f| f.data).load(ctx, "data"));
|
||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||
|
||||
// else_bb: reshape is impossible without copying
|
||||
ctx.builder.position_at_end(else_bb);
|
||||
let dst_ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, dst_ndarray);
|
||||
let data = byte_model.array_alloca(ctx, dst_ndarray_nbytes, "new_data").pointer;
|
||||
dst_ndarray.gep(ctx, |f| f.data).store(ctx, data);
|
||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||
|
||||
// Reposition for continuation
|
||||
ctx.builder.position_at_end(end_bb);
|
||||
|
||||
Ok(dst_ndarray)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `np.reshape`.
|
||||
pub fn gen_ndarray_reshape<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 2);
|
||||
|
||||
// Parse argument #1 ndarray
|
||||
let ndarray_ty = fun.0.args[0].ty;
|
||||
let ndarray_arg = args[0].1.clone().to_basic_value_enum(context, generator, ndarray_ty)?;
|
||||
|
||||
// Parse argument #2 shape
|
||||
let shape_ty = fun.0.args[1].ty;
|
||||
let shape_arg = args[1].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
let sizet = generator.get_sizet(context.ctx);
|
||||
let pndarray_model = PointerModel(StructModel(NpArray { sizet }));
|
||||
|
||||
let src_ndarray = pndarray_model.review_value(context.ctx, ndarray_arg).unwrap();
|
||||
let new_shape = parse_input_shape_arg(generator, context, shape_arg, shape_ty);
|
||||
|
||||
let reshaped_ndarray = reshape_ndarray_or_copy(generator, context, src_ndarray, &new_shape)?;
|
||||
Ok(reshaped_ndarray.value)
|
||||
}
|
|
@ -496,6 +496,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
| PrimDef::FunNpEye
|
||||
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
||||
|
||||
PrimDef::FunNpReshape => self.build_ndarray_view_functions(prim),
|
||||
|
||||
PrimDef::FunStr => self.build_str_function(),
|
||||
|
||||
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
|
||||
|
@ -1333,6 +1335,39 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
// Build functions related to NDArray views
|
||||
fn build_ndarray_view_functions(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape]);
|
||||
|
||||
match prim {
|
||||
PrimDef::FunNpReshape => {
|
||||
let new_ndim_ty = self.unifier.get_fresh_var(Some("NewNDim".into()), None);
|
||||
let returned_ndarray_ty = make_ndarray_ty(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(self.ndarray_dtype_tvar.ty),
|
||||
Some(new_ndim_ty.ty),
|
||||
);
|
||||
|
||||
create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&into_var_map([self.ndarray_dtype_tvar, self.ndarray_ndims_tvar, new_ndim_ty]),
|
||||
prim.name(),
|
||||
returned_ndarray_ty,
|
||||
&[
|
||||
(self.primitives.ndarray, "array"),
|
||||
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"),
|
||||
],
|
||||
Box::new(|ctx, obj, fun, args, generator| {
|
||||
numpy_new::view::gen_ndarray_reshape(ctx, &obj, fun, &args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the `str()` function.
|
||||
fn build_str_function(&mut self) -> TopLevelDef {
|
||||
let prim = PrimDef::FunStr;
|
||||
|
|
|
@ -46,6 +46,7 @@ pub enum PrimDef {
|
|||
FunNpArray,
|
||||
FunNpEye,
|
||||
FunNpIdentity,
|
||||
FunNpReshape,
|
||||
FunRound,
|
||||
FunRound64,
|
||||
FunNpRound,
|
||||
|
@ -204,6 +205,7 @@ impl PrimDef {
|
|||
PrimDef::FunNpArray => fun("np_array", None),
|
||||
PrimDef::FunNpEye => fun("np_eye", None),
|
||||
PrimDef::FunNpIdentity => fun("np_identity", None),
|
||||
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||
PrimDef::FunRound => fun("round", None),
|
||||
PrimDef::FunRound64 => fun("round64", None),
|
||||
PrimDef::FunNpRound => fun("np_round", None),
|
||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
|||
[
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar237]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar237\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(250)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar236, typevar237]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar236\", \"typevar237\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(256)]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n",
|
||||
]
|
||||
|
|
|
@ -1390,6 +1390,55 @@ impl<'a> Inferencer<'a> {
|
|||
}));
|
||||
}
|
||||
|
||||
// Handle `np.reshape(<array>, <shape>)`
|
||||
if ["np_reshape".into()].contains(id) && args.len() == 2 {
|
||||
// Extract arguments
|
||||
let array_expr = args.remove(0);
|
||||
let shape_expr = args.remove(0);
|
||||
|
||||
// Fold `<array>`
|
||||
let array = self.fold_expr(array_expr)?;
|
||||
let array_ty = array.custom.unwrap();
|
||||
let (array_dtype, _) = unpack_ndarray_var_tys(self.unifier, array_ty);
|
||||
|
||||
// Fold `<shape>`
|
||||
let (target_ndims, target_shape) =
|
||||
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?;
|
||||
let target_shape_ty = target_shape.custom.unwrap();
|
||||
// ... and deduce the return type of the call
|
||||
let target_ndims_ty =
|
||||
self.unifier.get_fresh_literal(vec![SymbolValue::U64(target_ndims)], None);
|
||||
let ret = make_ndarray_ty(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(array_dtype),
|
||||
Some(target_ndims_ty),
|
||||
);
|
||||
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg { name: "array".into(), ty: array_ty, default_value: None },
|
||||
FuncArg { name: "shape".into(), ty: target_shape_ty, default_value: None },
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||
}),
|
||||
args: vec![array, target_shape],
|
||||
keywords: vec![],
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
// 2-argument ndarray n-dimensional creation functions
|
||||
if id == &"np_full".into() && args.len() == 2 {
|
||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||
|
|
|
@ -178,6 +178,9 @@ def patch(module):
|
|||
module.np_identity = np.identity
|
||||
module.np_array = np.array
|
||||
|
||||
# NumPy view functions
|
||||
module.np_reshape = np.reshape
|
||||
|
||||
# NumPy Math functions
|
||||
module.np_isnan = np.isnan
|
||||
module.np_isinf = np.isinf
|
||||
|
|
Loading…
Reference in New Issue