diff --git a/nac3core/irrt/irrt/ndarray/reshape.hpp b/nac3core/irrt/irrt/ndarray/reshape.hpp new file mode 100644 index 00000000..cd873436 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/reshape.hpp @@ -0,0 +1,117 @@ +#pragma once + +#include +#include +#include + +namespace { +namespace ndarray { +namespace reshape { +namespace util { + +/** + * @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(, new_shape)` + * + * If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be + * modified to contain the resolved dimension. + * + * To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual + * `` object itself, but only the `.size` of the ``. + * + * @param size The `.size` of `` + * @param new_ndims Number of elements in `new_shape` + * @param new_shape Target shape to reshape to + */ +template +void resolve_and_check_new_shape(ErrorContext* errctx, SizeT size, + SizeT new_ndims, SizeT* new_shape) { + // Is there a -1 in `new_shape`? + bool neg1_exists = false; + // Location of -1, only initialized if `neg1_exists` is true + SizeT neg1_axis_i; + // The computed ndarray size of `new_shape` + SizeT new_size = 1; + + for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) { + SizeT dim = new_shape[axis_i]; + if (dim < 0) { + if (dim == -1) { + if (neg1_exists) { + // Multiple `-1` found. Throw an error. + errctx->set_exception( + errctx->exceptions->value_error, + "can only specify one unknown dimension"); + return; + } else { + neg1_exists = true; + neg1_axis_i = axis_i; + } + } else { + // TODO: What? In `np.reshape` any negative dimensions is + // treated like its `-1`. + // + // Try running `np.zeros((3, 4)).reshape((-999, 2))` + // + // It is not documented by numpy. + // Throw an error for now... + + errctx->set_exception( + errctx->exceptions->value_error, + "Found negative dimension {0} on axis {1}", dim, axis_i); + return; + } + } else { + new_size *= dim; + } + } + + bool can_reshape; + if (neg1_exists) { + // Let `x` be the unknown dimension + // solve `x * = ` + if (new_size == 0 && size == 0) { + // `x` has infinitely many solutions + can_reshape = false; + } else if (new_size == 0 && size != 0) { + // `x` has no solutions + can_reshape = false; + } else if (size % new_size != 0) { + // `x` has no integer solutions + can_reshape = false; + } else { + can_reshape = true; + new_shape[neg1_axis_i] = size / new_size; // Resolve dimension + } + } else { + can_reshape = (new_size == size); + } + + if (!can_reshape) { + errctx->set_exception( + errctx->exceptions->value_error, + "cannot reshape array of size {0} into given shape", size); + return; + } +} +} // namespace util +} // namespace reshape +} // namespace ndarray +} // namespace + +extern "C" { + +void __nac3_ndarray_resolve_and_check_new_shape(ErrorContext* errctx, + int32_t size, int32_t new_ndims, + int32_t* new_shape) { + ndarray::reshape::util::resolve_and_check_new_shape(errctx, size, new_ndims, + new_shape); +} + +void __nac3_ndarray_resolve_and_check_new_shape64(ErrorContext* errctx, + int64_t size, + int64_t new_ndims, + int64_t* new_shape) { + ndarray::reshape::util::resolve_and_check_new_shape(errctx, size, new_ndims, + new_shape); +} +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index feb2d9e1..1608b861 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -7,5 +7,6 @@ #include #include #include +#include #include #include \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/ndarray/mod.rs b/nac3core/src/codegen/irrt/ndarray/mod.rs index 83a88f08..6498e1c2 100644 --- a/nac3core/src/codegen/irrt/ndarray/mod.rs +++ b/nac3core/src/codegen/irrt/ndarray/mod.rs @@ -1,3 +1,4 @@ pub mod allocation; pub mod basic; pub mod indexing; +pub mod reshape; diff --git a/nac3core/src/codegen/irrt/ndarray/reshape.rs b/nac3core/src/codegen/irrt/ndarray/reshape.rs new file mode 100644 index 00000000..114f4cee --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray/reshape.rs @@ -0,0 +1,31 @@ +use crate::codegen::{ + irrt::{ + error_context::{check_error_context, setup_error_context}, + util::{function::CallFunction, get_sizet_dependent_function_name}, + }, + model::*, + CodeGenContext, CodeGenerator, +}; + +pub fn call_nac3_ndarray_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: Int<'ctx, SizeT>, + new_ndims: Int<'ctx, SizeT>, + new_shape: Ptr<'ctx, IntModel>, +) { + let tyctx = generator.type_context(ctx.ctx); + + let perrctx = setup_error_context(tyctx, ctx); + CallFunction::begin( + tyctx, + ctx, + &get_sizet_dependent_function_name(tyctx, "__nac3_ndarray_resolve_and_check_new_shape"), + ) + .arg("errctx", perrctx) + .arg("size", size) + .arg("new_ndims", new_ndims) + .arg("new_shape", new_shape) + .returning_void(); + check_error_context(generator, ctx, perrctx); +} diff --git a/nac3core/src/codegen/numpy_new/mod.rs b/nac3core/src/codegen/numpy_new/mod.rs index 0f1a26b8..96c68286 100644 --- a/nac3core/src/codegen/numpy_new/mod.rs +++ b/nac3core/src/codegen/numpy_new/mod.rs @@ -1,2 +1,3 @@ pub mod control; pub mod factory; +pub mod view; diff --git a/nac3core/src/codegen/numpy_new/view.rs b/nac3core/src/codegen/numpy_new/view.rs new file mode 100644 index 00000000..805e3568 --- /dev/null +++ b/nac3core/src/codegen/numpy_new/view.rs @@ -0,0 +1,135 @@ +use inkwell::values::PointerValue; +use nac3parser::ast::StrRef; + +use crate::{ + codegen::{ + irrt::ndarray::{ + allocation::{alloca_ndarray, init_ndarray_shape}, + basic::{ + call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_nbytes, + call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, + }, + reshape::call_nac3_ndarray_resolve_and_check_new_shape, + }, + model::*, + structure::ndarray::NpArray, + util::{array_writer::ArrayWriter, shape::make_shape_writer}, + CodeGenContext, CodeGenerator, + }, + symbol_resolver::ValueEnum, + toplevel::DefinitionId, + typecheck::typedef::{FunSignature, Type}, +}; + +fn gen_reshape_ndarray_or_copy<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Ptr<'ctx, StructModel>, + new_shape: &ArrayWriter<'ctx, G, SizeT, IntModel>, +) -> Result>, String> { + /* + Reference pseudo-code: + ```c + NDArray* src_ndarray; + + NDArray* dst_ndarray = __builtin_alloca(...); + dst_ndarray->ndims = ... + dst_ndarray->strides = __builtin_alloca(...); + dst_ndarray->shape = ... // Directly set by user, may contain -1, or even illegal values. + dst_ndarray->itemsize = src_ndarray->itemsize; + set_strides_by_shape(dst_ndarray); + + // Do assertions on `dst_ndarray->shape` and resolve -1 + + resolve_and_check_new_shape(ndarray_size(src_ndarray), dst_ndarray->shape); + + if (is_c_contiguous(src_ndarray)) { + dst_ndarray->data = src_ndarray->data; + } else { + dst_ndarray->data = __builtin_alloca( ndarray_nbytes(dst_ndarray) ); + copy_data(src_ndarray, dst_ndarray); + } + + return dst_ndarray; + ``` + */ + + let tyctx = generator.type_context(ctx.ctx); + let byte_model = IntModel(Byte); + + let current_bb = ctx.builder.get_insert_block().unwrap(); + let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then"); + let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb"); + let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb"); + + // Inserting into current_bb + let dst_ndarray = alloca_ndarray(generator, ctx, new_shape.len, "ndarray").unwrap(); + + init_ndarray_shape(generator, ctx, dst_ndarray, new_shape)?; + dst_ndarray + .gep(ctx, |f| f.itemsize) + .store(ctx, src_ndarray.gep(ctx, |f| f.itemsize).load(tyctx, ctx, "itemsize")); + + call_nac3_ndarray_set_strides_by_shape(generator, ctx, dst_ndarray); + + let src_ndarray_size = call_nac3_ndarray_size(generator, ctx, src_ndarray); + call_nac3_ndarray_resolve_and_check_new_shape( + generator, + ctx, + src_ndarray_size, + dst_ndarray.gep(ctx, |f| f.ndims).load(tyctx, ctx, "ndims"), + dst_ndarray.gep(ctx, |f| f.shape).load(tyctx, ctx, "shape"), + ); + + let is_c_contiguous = call_nac3_ndarray_is_c_contiguous(generator, ctx, src_ndarray); + ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap(); + + // Inserting into then_bb: reshape is possible without copying + ctx.builder.position_at_end(then_bb); + dst_ndarray + .gep(ctx, |f| f.data) + .store(ctx, src_ndarray.gep(ctx, |f| f.data).load(tyctx, ctx, "data")); + ctx.builder.build_unconditional_branch(end_bb).unwrap(); + + // Inserting into else_bb: reshape is impossible without copying + ctx.builder.position_at_end(else_bb); + let dst_ndarray_nbytes = call_nac3_ndarray_nbytes(generator, ctx, dst_ndarray); + let data = byte_model.array_alloca(tyctx, ctx, dst_ndarray_nbytes.value, "new_data"); + dst_ndarray.gep(ctx, |f| f.data).store(ctx, data); + ctx.builder.build_unconditional_branch(end_bb).unwrap(); + + // Reposition for continuation + ctx.builder.position_at_end(end_bb); + + Ok(dst_ndarray) +} + +/// Generates LLVM IR for `np.reshape`. +pub fn gen_ndarray_reshape<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 2); + + // Parse argument #1 ndarray + let ndarray_ty = fun.0.args[0].ty; + let ndarray_arg = args[0].1.clone().to_basic_value_enum(context, generator, ndarray_ty)?; + + // Parse argument #2 shape + let shape_ty = fun.0.args[1].ty; + let shape_arg = args[1].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + let tyctx = generator.type_context(context.ctx); + let pndarray_model = PtrModel(StructModel(NpArray)); + + let src_ndarray = pndarray_model.check_value(tyctx, context.ctx, ndarray_arg).unwrap(); + let new_shape = make_shape_writer(generator, context, shape_arg, shape_ty); + + let reshaped_ndarray = + gen_reshape_ndarray_or_copy(generator, context, src_ndarray, &new_shape)?; + Ok(reshaped_ndarray.value) +} diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index eb9ed31b..a53aa15b 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -496,6 +496,8 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), + PrimDef::FunNpReshape => self.build_ndarray_view_functions(prim), + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -1333,6 +1335,39 @@ impl<'a> BuiltinBuilder<'a> { } } + // Build functions related to NDArray views + fn build_ndarray_view_functions(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpReshape]); + + match prim { + PrimDef::FunNpReshape => { + let new_ndim_ty = self.unifier.get_fresh_var(Some("NewNDim".into()), None); + let returned_ndarray_ty = make_ndarray_ty( + self.unifier, + self.primitives, + Some(self.ndarray_dtype_tvar.ty), + Some(new_ndim_ty.ty), + ); + + create_fn_by_codegen( + self.unifier, + &into_var_map([self.ndarray_dtype_tvar, self.ndarray_ndims_tvar, new_ndim_ty]), + prim.name(), + returned_ndarray_ty, + &[ + (self.primitives.ndarray, "array"), + (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), + ], + Box::new(|ctx, obj, fun, args, generator| { + numpy_new::view::gen_ndarray_reshape(ctx, &obj, fun, &args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ) + } + _ => unreachable!(), + } + } + /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index c4a69638..206060d9 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -52,6 +52,9 @@ pub enum PrimDef { FunNpEye, FunNpIdentity, + // NumPy view functions + FunNpReshape, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -223,6 +226,9 @@ impl PrimDef { PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpIdentity => fun("np_identity", None), + // NumPy view functions + PrimDef::FunNpReshape => fun("np_reshape", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index c8ff7dba..a04424ef 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(248)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index b67596d8..284fdc12 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar237]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar237\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 08f254f5..7228fc3d 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(250)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index ce3b02ed..c4216511 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar236, typevar237]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar236\", \"typevar237\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index ae002764..8e6e2d3d 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(256)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index d9380ab1..06d473ff 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1390,6 +1390,55 @@ impl<'a> Inferencer<'a> { })); } + // Handle `np.reshape(, )` + if ["np_reshape".into()].contains(id) && args.len() == 2 { + // Extract arguments + let array_expr = args.remove(0); + let shape_expr = args.remove(0); + + // Fold `` + let array = self.fold_expr(array_expr)?; + let array_ty = array.custom.unwrap(); + let (array_dtype, _) = unpack_ndarray_var_tys(self.unifier, array_ty); + + // Fold `` + let (target_ndims, target_shape) = + self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; + let target_shape_ty = target_shape.custom.unwrap(); + // ... and deduce the return type of the call + let target_ndims_ty = + self.unifier.get_fresh_literal(vec![SymbolValue::U64(target_ndims)], None); + let ret = make_ndarray_ty( + self.unifier, + self.primitives, + Some(array_dtype), + Some(target_ndims_ty), + ); + + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "array".into(), ty: array_ty, default_value: None }, + FuncArg { name: "shape".into(), ty: target_shape_ty, default_value: None }, + ], + ret, + vars: VarMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: *ctx }, + }), + args: vec![array, target_shape], + keywords: vec![], + }, + })); + } + // 2-argument ndarray n-dimensional creation functions if id == &"np_full".into() && args.len() == 2 { let ExprKind::List { elts, .. } = &args[0].node else { diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index b948edee..4267aa66 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -178,6 +178,9 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy view functions + module.np_reshape = np.reshape + # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf