From 00236f48bc57c112566edaea4ca36063bb647c7b Mon Sep 17 00:00:00 2001 From: abdul124 Date: Wed, 31 Jul 2024 13:16:42 +0800 Subject: [PATCH] core: add np.transpose and np.reshape functions --- nac3core/src/codegen/numpy.rs | 391 ++++++++++++++++++ nac3core/src/toplevel/builtins.rs | 55 +++ nac3core/src/toplevel/helper.rs | 6 + ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/type_inferencer/mod.rs | 38 ++ 9 files changed, 497 insertions(+), 7 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index a2b3c2c1..4ab1391e 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -2026,3 +2026,394 @@ pub fn gen_ndarray_fill<'ctx>( Ok(()) } + +/// Generates LLVM IR for `ndarray.transpose`. +pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "ndarray_transpose"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + + // Dimensions are reversed in the transposed array + let out = create_ndarray_dyn_shape( + generator, + ctx, + elem_ty, + &n1, + |_, ctx, n| Ok(n.load_ndims(ctx)), + |generator, ctx, n, idx| { + let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap(); + let new_idx = ctx + .builder + .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") + .unwrap(); + unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) } + }, + ) + .unwrap(); + + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (n_sz, false), + |generator, ctx, _, idx| { + let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; + + let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; + let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; + ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap(); + ctx.builder.build_store(rem_idx, idx).unwrap(); + + // Incrementally calculate the new index in the transposed array + // For each index, we first decompose it into the n-dims and use those to reconstruct the new index + // The formula used for indexing is: + // idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (n1.load_ndims(ctx), false), + |generator, ctx, _, ndim| { + let ndim_rev = + ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap(); + let ndim_rev = ctx + .builder + .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") + .unwrap(); + let dim = unsafe { + n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None) + }; + + let rem_idx_val = + ctx.builder.build_load(rem_idx, "").unwrap().into_int_value(); + let new_idx_val = + ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); + + let add_component = + ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap(); + let rem_idx_val = + ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap(); + + let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap(); + let new_idx_val = + ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap(); + + ctx.builder.build_store(rem_idx, rem_idx_val).unwrap(); + ctx.builder.build_store(new_idx, new_idx_val).unwrap(); + + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); + unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) }; + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + Ok(out.as_base_value().into()) + } else { + unreachable!( + "{FN_NAME}() not supported for '{}'", + format!("'{}'", ctx.unifier.stringify(x1_ty)) + ) + } +} + +/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`. +/// +/// * `x1` - `NDArray` to reshape. +/// * `shape` - The `shape` parameter used to construct the new `NDArray`. +/// Just like numpy, the `shape` argument can be: +/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])` +/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` +/// Note that unlike other generating functions, one of the dimesions in the shape can be negative +pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), + shape: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "ndarray_reshape"; + let (x1_ty, x1) = x1; + let (_, shape) = shape; + + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + + let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; + let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; + ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap(); + ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap(); + + let out = match shape { + BasicValueEnum::PointerValue(shape_list_ptr) + if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => + { + // 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])` + + let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); + // Check for -1 in dimensions + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (shape_list.load_size(ctx, None), false), + |generator, ctx, _, idx| { + let ele = + shape_list.data().get(ctx, generator, &idx, None).into_int_value(); + let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap(); + + gen_if_else_expr_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare( + IntPredicate::SLT, + ele, + llvm_usize.const_zero(), + "", + ) + .unwrap()) + }, + |_, ctx| -> Result, String> { + let num_neg_value = + ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); + let num_neg_value = ctx + .builder + .build_int_add( + num_neg_value, + llvm_usize.const_int(1, false), + "", + ) + .unwrap(); + ctx.builder.build_store(num_neg, num_neg_value).unwrap(); + Ok(None) + }, + |_, ctx| { + let acc_value = + ctx.builder.build_load(acc, "").unwrap().into_int_value(); + let acc_value = + ctx.builder.build_int_mul(acc_value, ele, "").unwrap(); + ctx.builder.build_store(acc, acc_value).unwrap(); + Ok(None) + }, + )?; + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); + let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); + // Generate the output shape by filling -1 with `rem` + create_ndarray_dyn_shape( + generator, + ctx, + elem_ty, + &shape_list, + |_, ctx, _| Ok(shape_list.load_size(ctx, None)), + |generator, ctx, shape_list, idx| { + let dim = + shape_list.data().get(ctx, generator, &idx, None).into_int_value(); + let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); + + Ok(gen_if_else_expr_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare( + IntPredicate::SLT, + dim, + llvm_usize.const_zero(), + "", + ) + .unwrap()) + }, + |_, _| Ok(Some(rem)), + |_, _| Ok(Some(dim)), + )? + .unwrap() + .into_int_value()) + }, + ) + } + BasicValueEnum::StructValue(shape_tuple) => { + // 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` + + let ndims = shape_tuple.get_type().count_fields(); + // Check for -1 in dims + for dim_i in 0..ndims { + let dim = ctx + .builder + .build_extract_value(shape_tuple, dim_i, "") + .unwrap() + .into_int_value(); + let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); + + gen_if_else_expr_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare( + IntPredicate::SLT, + dim, + llvm_usize.const_zero(), + "", + ) + .unwrap()) + }, + |_, ctx| -> Result, String> { + let num_negs = + ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); + let num_negs = ctx + .builder + .build_int_add(num_negs, llvm_usize.const_int(1, false), "") + .unwrap(); + ctx.builder.build_store(num_neg, num_negs).unwrap(); + Ok(None) + }, + |_, ctx| { + let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); + let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap(); + ctx.builder.build_store(acc, acc_val).unwrap(); + Ok(None) + }, + )?; + } + + let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); + let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); + let mut shape = Vec::with_capacity(ndims as usize); + + // Reconstruct shape filling negatives with rem + for dim_i in 0..ndims { + let dim = ctx + .builder + .build_extract_value(shape_tuple, dim_i, "") + .unwrap() + .into_int_value(); + let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); + + let dim = gen_if_else_expr_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare( + IntPredicate::SLT, + dim, + llvm_usize.const_zero(), + "", + ) + .unwrap()) + }, + |_, _| Ok(Some(rem)), + |_, _| Ok(Some(dim)), + )? + .unwrap() + .into_int_value(); + shape.push(dim); + } + create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) + } + BasicValueEnum::IntValue(shape_int) => { + // 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` + let shape_int = gen_if_else_expr_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx + .builder + .build_int_compare( + IntPredicate::SLT, + shape_int, + llvm_usize.const_zero(), + "", + ) + .unwrap()) + }, + |_, _| Ok(Some(n_sz)), + |_, ctx| { + Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap())) + }, + )? + .unwrap() + .into_int_value(); + create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) + } + _ => unreachable!(), + } + .unwrap(); + + // Only allow one dimension to be negative + let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); + ctx.make_assert( + generator, + ctx.builder + .build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "") + .unwrap(), + "0:ValueError", + "can only specify one unknown dimension", + [None, None, None], + ctx.current_loc, + ); + + // The new shape must be compatible with the old shape + let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None)); + ctx.make_assert( + generator, + ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), + "0:ValueError", + "cannot reshape array of size {} into provided shape of size {}", + [Some(n_sz), Some(out_sz), None], + ctx.current_loc, + ); + + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (n_sz, false), + |generator, ctx, _, idx| { + let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; + unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) }; + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + Ok(out.as_base_value().into()) + } else { + unreachable!( + "{FN_NAME}() not supported for '{}'", + format!("'{}'", ctx.unifier.stringify(x1_ty)) + ) + } +} diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 0d65828c..18f0be65 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -557,6 +557,10 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpHypot | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), + PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + self.build_np_sp_ndarray_function(prim) + } + PrimDef::FunNpDot | PrimDef::FunNpLinalgMatmul | PrimDef::FunNpLinalgCholesky @@ -1885,6 +1889,57 @@ impl<'a> BuiltinBuilder<'a> { } } + /// Build np/sp functions that take as input `NDArray` only + fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); + + match prim { + PrimDef::FunNpTranspose => { + let ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.ndarray_num_ty], + Some("T".into()), + None, + ); + create_fn_by_codegen( + self.unifier, + &into_var_map([ndarray_ty]), + prim.name(), + ndarray_ty.ty, + &[(ndarray_ty.ty, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) + }), + ) + } + + // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and + // the `param_ty` for `create_fn_by_codegen`. + // + // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking + // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], + // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. + PrimDef::FunNpReshape => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_num_ty, + &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + }), + ), + + _ => unreachable!(), + } + } + /// Build `np_linalg` and `sp_linalg` functions /// /// The input to these functions must be floating point `NDArray` diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 50deabca..ae17e62c 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -99,6 +99,8 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, + FunNpTranspose, + FunNpReshape, // Linalg functions FunNpDot, @@ -282,6 +284,10 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), + PrimDef::FunNpTranspose => fun("np_transpose", None), + PrimDef::FunNpReshape => fun("np_reshape", None), + + // Linalg functions PrimDef::FunNpDot => fun("np_dot", None), PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None), PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index c8ff7dba..78e19dda 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index b67596d8..b4df49c9 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 08f254f5..65a6a8ac 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index ce3b02ed..cfedf1f6 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index ae002764..e84c450a 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index d9380ab1..8e8cba39 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1389,7 +1389,45 @@ impl<'a> Inferencer<'a> { }, })); } + // 2-argument ndarray n-dimensional factory functions + if id == &"np_reshape".into() && args.len() == 2 { + let arg0 = self.fold_expr(args.remove(0))?; + let shape_expr = args.remove(0); + let (ndims, shape) = + self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape` + + let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); + let (elem_ty, _) = unpack_ndarray_var_tys(self.unifier, arg0.custom.unwrap()); + let ret = make_ndarray_ty(self.unifier, self.primitives, Some(elem_ty), Some(ndims)); + + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None }, + FuncArg { + name: "shape".into(), + ty: shape.custom.unwrap(), + default_value: None, + }, + ], + ret, + vars: VarMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: *ctx }, + }), + args: vec![arg0, shape], + keywords: vec![], + }, + })); + } // 2-argument ndarray n-dimensional creation functions if id == &"np_full".into() && args.len() == 2 { let ExprKind::List { elts, .. } = &args[0].node else {