forked from M-Labs/nac3
core/ndstrides: implement general ndarray reshaping
This commit is contained in:
parent
bd5cb14d0d
commit
2747869a45
|
@ -0,0 +1,117 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/error_context.hpp>
|
||||||
|
#include <irrt/int_defs.hpp>
|
||||||
|
#include <irrt/ndarray/def.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
namespace ndarray {
|
||||||
|
namespace reshape {
|
||||||
|
namespace util {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
|
||||||
|
*
|
||||||
|
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
|
||||||
|
* modified to contain the resolved dimension.
|
||||||
|
*
|
||||||
|
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
|
||||||
|
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
|
||||||
|
*
|
||||||
|
* @param size The `.size` of `<ndarray>`
|
||||||
|
* @param new_ndims Number of elements in `new_shape`
|
||||||
|
* @param new_shape Target shape to reshape to
|
||||||
|
*/
|
||||||
|
template <typename SizeT>
|
||||||
|
void resolve_and_check_new_shape(ErrorContext* errctx, SizeT size,
|
||||||
|
SizeT new_ndims, SizeT* new_shape) {
|
||||||
|
// Is there a -1 in `new_shape`?
|
||||||
|
bool neg1_exists = false;
|
||||||
|
// Location of -1, only initialized if `neg1_exists` is true
|
||||||
|
SizeT neg1_axis_i;
|
||||||
|
// The computed ndarray size of `new_shape`
|
||||||
|
SizeT new_size = 1;
|
||||||
|
|
||||||
|
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
|
||||||
|
SizeT dim = new_shape[axis_i];
|
||||||
|
if (dim < 0) {
|
||||||
|
if (dim == -1) {
|
||||||
|
if (neg1_exists) {
|
||||||
|
// Multiple `-1` found. Throw an error.
|
||||||
|
errctx->set_exception(
|
||||||
|
errctx->exceptions->value_error,
|
||||||
|
"can only specify one unknown dimension");
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
neg1_exists = true;
|
||||||
|
neg1_axis_i = axis_i;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// TODO: What? In `np.reshape` any negative dimensions is
|
||||||
|
// treated like its `-1`.
|
||||||
|
//
|
||||||
|
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
|
||||||
|
//
|
||||||
|
// It is not documented by numpy.
|
||||||
|
// Throw an error for now...
|
||||||
|
|
||||||
|
errctx->set_exception(
|
||||||
|
errctx->exceptions->value_error,
|
||||||
|
"Found negative dimension {0} on axis {1}", dim, axis_i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
new_size *= dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool can_reshape;
|
||||||
|
if (neg1_exists) {
|
||||||
|
// Let `x` be the unknown dimension
|
||||||
|
// solve `x * <new_size> = <size>`
|
||||||
|
if (new_size == 0 && size == 0) {
|
||||||
|
// `x` has infinitely many solutions
|
||||||
|
can_reshape = false;
|
||||||
|
} else if (new_size == 0 && size != 0) {
|
||||||
|
// `x` has no solutions
|
||||||
|
can_reshape = false;
|
||||||
|
} else if (size % new_size != 0) {
|
||||||
|
// `x` has no integer solutions
|
||||||
|
can_reshape = false;
|
||||||
|
} else {
|
||||||
|
can_reshape = true;
|
||||||
|
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
can_reshape = (new_size == size);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!can_reshape) {
|
||||||
|
errctx->set_exception(
|
||||||
|
errctx->exceptions->value_error,
|
||||||
|
"cannot reshape array of size {0} into given shape", size);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace util
|
||||||
|
} // namespace reshape
|
||||||
|
} // namespace ndarray
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
void __nac3_ndarray_resolve_and_check_new_shape(ErrorContext* errctx,
|
||||||
|
int32_t size, int32_t new_ndims,
|
||||||
|
int32_t* new_shape) {
|
||||||
|
ndarray::reshape::util::resolve_and_check_new_shape(errctx, size, new_ndims,
|
||||||
|
new_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_resolve_and_check_new_shape64(ErrorContext* errctx,
|
||||||
|
int64_t size,
|
||||||
|
int64_t new_ndims,
|
||||||
|
int64_t* new_shape) {
|
||||||
|
ndarray::reshape::util::resolve_and_check_new_shape(errctx, size, new_ndims,
|
||||||
|
new_shape);
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,5 +7,6 @@
|
||||||
#include <irrt/ndarray/basic.hpp>
|
#include <irrt/ndarray/basic.hpp>
|
||||||
#include <irrt/ndarray/def.hpp>
|
#include <irrt/ndarray/def.hpp>
|
||||||
#include <irrt/ndarray/indexing.hpp>
|
#include <irrt/ndarray/indexing.hpp>
|
||||||
|
#include <irrt/ndarray/reshape.hpp>
|
||||||
#include <irrt/slice.hpp>
|
#include <irrt/slice.hpp>
|
||||||
#include <irrt/utils.hpp>
|
#include <irrt/utils.hpp>
|
|
@ -1,3 +1,4 @@
|
||||||
pub mod allocation;
|
pub mod allocation;
|
||||||
pub mod basic;
|
pub mod basic;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
|
pub mod reshape;
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
use crate::codegen::{
|
||||||
|
irrt::{
|
||||||
|
error_context::{check_error_context, setup_error_context},
|
||||||
|
util::{function::CallFunction, get_sizet_dependent_function_name},
|
||||||
|
},
|
||||||
|
model::*,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn call_nac3_ndarray_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
size: Int<'ctx, SizeT>,
|
||||||
|
new_ndims: Int<'ctx, SizeT>,
|
||||||
|
new_shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
|
) {
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
|
||||||
|
let perrctx = setup_error_context(tyctx, ctx);
|
||||||
|
CallFunction::begin(
|
||||||
|
tyctx,
|
||||||
|
ctx,
|
||||||
|
&get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_resolve_and_check_new_shape"),
|
||||||
|
)
|
||||||
|
.arg("errctx", perrctx)
|
||||||
|
.arg("size", size)
|
||||||
|
.arg("new_ndims", new_ndims)
|
||||||
|
.arg("new_shape", new_shape)
|
||||||
|
.returning_void();
|
||||||
|
check_error_context(generator, ctx, perrctx);
|
||||||
|
}
|
|
@ -1,2 +1,3 @@
|
||||||
pub mod control;
|
pub mod control;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
|
pub mod view;
|
||||||
|
|
|
@ -0,0 +1,135 @@
|
||||||
|
use inkwell::values::PointerValue;
|
||||||
|
use nac3parser::ast::StrRef;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
irrt::ndarray::{
|
||||||
|
allocation::{alloca_ndarray, init_ndarray_shape},
|
||||||
|
basic::{
|
||||||
|
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_nbytes,
|
||||||
|
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
||||||
|
},
|
||||||
|
reshape::call_nac3_ndarray_resolve_and_check_new_shape,
|
||||||
|
},
|
||||||
|
model::*,
|
||||||
|
structure::ndarray::NpArray,
|
||||||
|
util::{array_writer::ArrayWriter, shape::make_shape_writer},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
symbol_resolver::ValueEnum,
|
||||||
|
toplevel::DefinitionId,
|
||||||
|
typecheck::typedef::{FunSignature, Type},
|
||||||
|
};
|
||||||
|
|
||||||
|
fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
|
||||||
|
new_shape: &ArrayWriter<'ctx, G, SizeT, IntModel<SizeT>>,
|
||||||
|
) -> Result<Ptr<'ctx, StructModel<NpArray>>, String> {
|
||||||
|
/*
|
||||||
|
Reference pseudo-code:
|
||||||
|
```c
|
||||||
|
NDArray<SizeT>* src_ndarray;
|
||||||
|
|
||||||
|
NDArray<SizeT>* dst_ndarray = __builtin_alloca(...);
|
||||||
|
dst_ndarray->ndims = ...
|
||||||
|
dst_ndarray->strides = __builtin_alloca(...);
|
||||||
|
dst_ndarray->shape = ... // Directly set by user, may contain -1, or even illegal values.
|
||||||
|
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||||
|
set_strides_by_shape(dst_ndarray);
|
||||||
|
|
||||||
|
// Do assertions on `dst_ndarray->shape` and resolve -1
|
||||||
|
|
||||||
|
resolve_and_check_new_shape(ndarray_size(src_ndarray), dst_ndarray->shape);
|
||||||
|
|
||||||
|
if (is_c_contiguous(src_ndarray)) {
|
||||||
|
dst_ndarray->data = src_ndarray->data;
|
||||||
|
} else {
|
||||||
|
dst_ndarray->data = __builtin_alloca( ndarray_nbytes(dst_ndarray) );
|
||||||
|
copy_data(src_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
|
return dst_ndarray;
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
let byte_model = IntModel(Byte);
|
||||||
|
|
||||||
|
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||||
|
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then");
|
||||||
|
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
|
||||||
|
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
||||||
|
|
||||||
|
// Inserting into current_bb
|
||||||
|
let dst_ndarray = alloca_ndarray(generator, ctx, new_shape.len, "ndarray").unwrap();
|
||||||
|
|
||||||
|
init_ndarray_shape(generator, ctx, dst_ndarray, new_shape)?;
|
||||||
|
dst_ndarray
|
||||||
|
.gep(ctx, |f| f.itemsize)
|
||||||
|
.store(ctx, src_ndarray.gep(ctx, |f| f.itemsize).load(tyctx, ctx, "itemsize"));
|
||||||
|
|
||||||
|
call_nac3_ndarray_set_strides_by_shape(generator, ctx, dst_ndarray);
|
||||||
|
|
||||||
|
let src_ndarray_size = call_nac3_ndarray_size(generator, ctx, src_ndarray);
|
||||||
|
call_nac3_ndarray_resolve_and_check_new_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
src_ndarray_size,
|
||||||
|
dst_ndarray.gep(ctx, |f| f.ndims).load(tyctx, ctx, "ndims"),
|
||||||
|
dst_ndarray.gep(ctx, |f| f.shape).load(tyctx, ctx, "shape"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let is_c_contiguous = call_nac3_ndarray_is_c_contiguous(generator, ctx, src_ndarray);
|
||||||
|
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
|
||||||
|
|
||||||
|
// Inserting into then_bb: reshape is possible without copying
|
||||||
|
ctx.builder.position_at_end(then_bb);
|
||||||
|
dst_ndarray
|
||||||
|
.gep(ctx, |f| f.data)
|
||||||
|
.store(ctx, src_ndarray.gep(ctx, |f| f.data).load(tyctx, ctx, "data"));
|
||||||
|
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||||
|
|
||||||
|
// Inserting into else_bb: reshape is impossible without copying
|
||||||
|
ctx.builder.position_at_end(else_bb);
|
||||||
|
let dst_ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, dst_ndarray);
|
||||||
|
let data = byte_model.array_alloca(tyctx, ctx, dst_ndarray_nbytes.value, "new_data");
|
||||||
|
dst_ndarray.gep(ctx, |f| f.data).store(ctx, data);
|
||||||
|
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||||
|
|
||||||
|
// Reposition for continuation
|
||||||
|
ctx.builder.position_at_end(end_bb);
|
||||||
|
|
||||||
|
Ok(dst_ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `np.reshape`.
|
||||||
|
pub fn gen_ndarray_reshape<'ctx>(
|
||||||
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 2);
|
||||||
|
|
||||||
|
// Parse argument #1 ndarray
|
||||||
|
let ndarray_ty = fun.0.args[0].ty;
|
||||||
|
let ndarray_arg = args[0].1.clone().to_basic_value_enum(context, generator, ndarray_ty)?;
|
||||||
|
|
||||||
|
// Parse argument #2 shape
|
||||||
|
let shape_ty = fun.0.args[1].ty;
|
||||||
|
let shape_arg = args[1].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
let tyctx = generator.type_context(context.ctx);
|
||||||
|
let pndarray_model = PtrModel(StructModel(NpArray));
|
||||||
|
|
||||||
|
let src_ndarray = pndarray_model.check_value(tyctx, context.ctx, ndarray_arg).unwrap();
|
||||||
|
let new_shape = make_shape_writer(generator, context, shape_arg, shape_ty);
|
||||||
|
|
||||||
|
let reshaped_ndarray =
|
||||||
|
gen_reshape_ndarray_or_copy(generator, context, src_ndarray, &new_shape)?;
|
||||||
|
Ok(reshaped_ndarray.value)
|
||||||
|
}
|
|
@ -496,6 +496,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
| PrimDef::FunNpEye
|
| PrimDef::FunNpEye
|
||||||
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
|
||||||
|
|
||||||
|
PrimDef::FunNpReshape => self.build_ndarray_view_functions(prim),
|
||||||
|
|
||||||
PrimDef::FunStr => self.build_str_function(),
|
PrimDef::FunStr => self.build_str_function(),
|
||||||
|
|
||||||
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
|
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
|
||||||
|
@ -1333,6 +1335,39 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build functions related to NDArray views
|
||||||
|
fn build_ndarray_view_functions(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape]);
|
||||||
|
|
||||||
|
match prim {
|
||||||
|
PrimDef::FunNpReshape => {
|
||||||
|
let new_ndim_ty = self.unifier.get_fresh_var(Some("NewNDim".into()), None);
|
||||||
|
let returned_ndarray_ty = make_ndarray_ty(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(self.ndarray_dtype_tvar.ty),
|
||||||
|
Some(new_ndim_ty.ty),
|
||||||
|
);
|
||||||
|
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&into_var_map([self.ndarray_dtype_tvar, self.ndarray_ndims_tvar, new_ndim_ty]),
|
||||||
|
prim.name(),
|
||||||
|
returned_ndarray_ty,
|
||||||
|
&[
|
||||||
|
(self.primitives.ndarray, "array"),
|
||||||
|
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"),
|
||||||
|
],
|
||||||
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
|
numpy_new::view::gen_ndarray_reshape(ctx, &obj, fun, &args, generator)
|
||||||
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Build the `str()` function.
|
/// Build the `str()` function.
|
||||||
fn build_str_function(&mut self) -> TopLevelDef {
|
fn build_str_function(&mut self) -> TopLevelDef {
|
||||||
let prim = PrimDef::FunStr;
|
let prim = PrimDef::FunStr;
|
||||||
|
|
|
@ -52,6 +52,9 @@ pub enum PrimDef {
|
||||||
FunNpEye,
|
FunNpEye,
|
||||||
FunNpIdentity,
|
FunNpIdentity,
|
||||||
|
|
||||||
|
// NumPy view functions
|
||||||
|
FunNpReshape,
|
||||||
|
|
||||||
// Miscellaneous NumPy & SciPy functions
|
// Miscellaneous NumPy & SciPy functions
|
||||||
FunNpRound,
|
FunNpRound,
|
||||||
FunNpFloor,
|
FunNpFloor,
|
||||||
|
@ -223,6 +226,9 @@ impl PrimDef {
|
||||||
PrimDef::FunNpEye => fun("np_eye", None),
|
PrimDef::FunNpEye => fun("np_eye", None),
|
||||||
PrimDef::FunNpIdentity => fun("np_identity", None),
|
PrimDef::FunNpIdentity => fun("np_identity", None),
|
||||||
|
|
||||||
|
// NumPy view functions
|
||||||
|
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||||
|
|
||||||
// Miscellaneous NumPy & SciPy functions
|
// Miscellaneous NumPy & SciPy functions
|
||||||
PrimDef::FunNpRound => fun("np_round", None),
|
PrimDef::FunNpRound => fun("np_round", None),
|
||||||
PrimDef::FunNpFloor => fun("np_floor", None),
|
PrimDef::FunNpFloor => fun("np_floor", None),
|
||||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar237]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar237\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(250)]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar236, typevar237]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar236\", \"typevar237\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(256)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1390,6 +1390,55 @@ impl<'a> Inferencer<'a> {
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle `np.reshape(<array>, <shape>)`
|
||||||
|
if ["np_reshape".into()].contains(id) && args.len() == 2 {
|
||||||
|
// Extract arguments
|
||||||
|
let array_expr = args.remove(0);
|
||||||
|
let shape_expr = args.remove(0);
|
||||||
|
|
||||||
|
// Fold `<array>`
|
||||||
|
let array = self.fold_expr(array_expr)?;
|
||||||
|
let array_ty = array.custom.unwrap();
|
||||||
|
let (array_dtype, _) = unpack_ndarray_var_tys(self.unifier, array_ty);
|
||||||
|
|
||||||
|
// Fold `<shape>`
|
||||||
|
let (target_ndims, target_shape) =
|
||||||
|
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?;
|
||||||
|
let target_shape_ty = target_shape.custom.unwrap();
|
||||||
|
// ... and deduce the return type of the call
|
||||||
|
let target_ndims_ty =
|
||||||
|
self.unifier.get_fresh_literal(vec![SymbolValue::U64(target_ndims)], None);
|
||||||
|
let ret = make_ndarray_ty(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(array_dtype),
|
||||||
|
Some(target_ndims_ty),
|
||||||
|
);
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg { name: "array".into(), ty: array_ty, default_value: None },
|
||||||
|
FuncArg { name: "shape".into(), ty: target_shape_ty, default_value: None },
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||||
|
}),
|
||||||
|
args: vec![array, target_shape],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
// 2-argument ndarray n-dimensional creation functions
|
// 2-argument ndarray n-dimensional creation functions
|
||||||
if id == &"np_full".into() && args.len() == 2 {
|
if id == &"np_full".into() && args.len() == 2 {
|
||||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||||
|
|
|
@ -178,6 +178,9 @@ def patch(module):
|
||||||
module.np_identity = np.identity
|
module.np_identity = np.identity
|
||||||
module.np_array = np.array
|
module.np_array = np.array
|
||||||
|
|
||||||
|
# NumPy view functions
|
||||||
|
module.np_reshape = np.reshape
|
||||||
|
|
||||||
# NumPy Math functions
|
# NumPy Math functions
|
||||||
module.np_isnan = np.isnan
|
module.np_isnan = np.isnan
|
||||||
module.np_isinf = np.isinf
|
module.np_isinf = np.isinf
|
||||||
|
|
Loading…
Reference in New Issue