From 5b11a1dbdd229b6f7943091552fac09882aa7393 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 25 Jun 2024 15:35:02 +0800 Subject: [PATCH] core: support tuple and int32 input for np_empty, np_ones, and more --- nac3core/src/codegen/numpy.rs | 125 +++++++------ nac3core/src/toplevel/builtins.rs | 40 ++++- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/type_inferencer/mod.rs | 166 ++++++++++++++++-- nac3standalone/demo/src/ndarray.py | 35 +++- 9 files changed, 298 insertions(+), 82 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 7f19f4ed..3fab259e 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -163,10 +163,11 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { let llvm_usize = generator.get_size_type(ctx.ctx); - for shape_dim in shape { + for &shape_dim in shape { + let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let shape_dim_gez = ctx .builder - .build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "") + .build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "") .unwrap(); ctx.make_assert( @@ -189,7 +190,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( let ndarray_num_dims = ndarray.load_ndims(ctx); ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); - for (i, shape_dim) in shape.iter().enumerate() { + for (i, &shape_dim) in shape.iter().enumerate() { + let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let ndarray_dim = unsafe { ndarray.dim_sizes().ptr_offset_unchecked( ctx, @@ -199,7 +201,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( ) }; - ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap(); + ctx.builder.build_store(ndarray_dim, shape_dim).unwrap(); } let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); @@ -286,22 +288,68 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( /// /// * `elem_ty` - The element type of the `NDArray`. /// * `shape` - The `shape` parameter used to construct the `NDArray`. +/// +/// ### Notes on `shape` +/// +/// Just like numpy, the `shape` argument can be: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` +/// +/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to +/// learn how `shape` gets from being a Python user expression to here. fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, - shape: ListValue<'ctx>, + shape: BasicValueEnum<'ctx>, ) -> Result, String> { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape, - |_, ctx, shape| Ok(shape.load_size(ctx, None)), - |generator, ctx, shape, idx| { - Ok(shape.data().get(ctx, generator, &idx, None).into_int_value()) - }, - ) + let llvm_usize = generator.get_size_type(ctx.ctx); + + match shape { + BasicValueEnum::PointerValue(shape_list_ptr) + if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => + { + // 1. A list of ints; e.g., `np.empty([600, 800, 3])` + + let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); + create_ndarray_dyn_shape( + generator, + ctx, + elem_ty, + &shape_list, + |_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)), + |generator, ctx, shape_list, idx| { + Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value()) + }, + ) + } + BasicValueEnum::StructValue(shape_tuple) => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. + + // Get the length/size of the tuple, which also happens to be the value of `ndims`. + let ndims = shape_tuple.get_type().count_fields(); + + let mut shape = Vec::with_capacity(ndims as usize); + for dim_i in 0..ndims { + let dim = ctx + .builder + .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) + .unwrap() + .into_int_value(); + + shape.push(dim); + } + create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) + } + BasicValueEnum::IntValue(shape_int) => { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + + create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) + } + _ => unreachable!(), + } } /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as @@ -486,7 +534,7 @@ fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, - shape: ListValue<'ctx>, + shape: BasicValueEnum<'ctx>, ) -> Result, String> { let supported_types = [ ctx.primitives.int32, @@ -517,7 +565,7 @@ fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, - shape: ListValue<'ctx>, + shape: BasicValueEnum<'ctx>, ) -> Result, String> { let supported_types = [ ctx.primitives.int32, @@ -548,7 +596,7 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, - shape: ListValue<'ctx>, + shape: BasicValueEnum<'ctx>, fill_value: BasicValueEnum<'ctx>, ) -> Result, String> { let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; @@ -1674,17 +1722,11 @@ pub fn gen_ndarray_empty<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_empty_impl( - generator, - context, - context.primitives.float, - ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ) - .map(NDArrayValue::into) + call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg) + .map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.zeros`. @@ -1698,17 +1740,11 @@ pub fn gen_ndarray_zeros<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_zeros_impl( - generator, - context, - context.primitives.float, - ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ) - .map(NDArrayValue::into) + call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg) + .map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.ones`. @@ -1722,17 +1758,11 @@ pub fn gen_ndarray_ones<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_ones_impl( - generator, - context, - context.primitives.float, - ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - ) - .map(NDArrayValue::into) + call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg) + .map(NDArrayValue::into) } /// Generates LLVM IR for `ndarray.full`. @@ -1746,21 +1776,14 @@ pub fn gen_ndarray_full<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 2); - let llvm_usize = generator.get_size_type(context.ctx); let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; let fill_value_ty = fun.0.args[1].ty; let fill_value_arg = args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; - call_ndarray_full_impl( - generator, - context, - fill_value_ty, - ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None), - fill_value_arg, - ) - .map(NDArrayValue::into) + call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg) + .map(NDArrayValue::into) } pub fn gen_ndarray_array<'ctx>( diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 5531bddc..2524b9a7 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -324,6 +324,9 @@ struct BuiltinBuilder<'a> { num_or_ndarray_ty: TypeVar, num_or_ndarray_var_map: VarMap, + + /// See [`BuiltinBuilder::build_ndarray_from_shape_factory_function`] + ndarray_factory_fn_shape_arg_tvar: TypeVar, } impl<'a> BuiltinBuilder<'a> { @@ -394,6 +397,8 @@ impl<'a> BuiltinBuilder<'a> { let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); + let ndarray_factory_fn_shape_arg_tvar = unifier.get_fresh_var(Some("Shape".into()), None); + BuiltinBuilder { unifier, primitives, @@ -421,6 +426,8 @@ impl<'a> BuiltinBuilder<'a> { num_or_ndarray_ty, num_or_ndarray_var_map, + + ndarray_factory_fn_shape_arg_tvar, } } @@ -959,21 +966,46 @@ impl<'a> BuiltinBuilder<'a> { ) } - /// Build ndarray factory functions that only take in an argument `shape` of type `list[int32]` and return an ndarray. + /// Build ndarray factory functions that only take in an argument `shape`. + /// + /// `shape` can be a tuple of int32s, a list of int32s, or a scalar int32. fn build_ndarray_from_shape_factory_function(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed( prim, &[PrimDef::FunNpNDArray, PrimDef::FunNpEmpty, PrimDef::FunNpZeros, PrimDef::FunNpOnes], ); + // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and + // the `param_ty` for `create_fn_by_codegen`. + // + // Ideally, we should have created a [`TypeVar`] to define all possible input + // types for the parameter "shape" like so: + // ```rust + // self.unifier.get_fresh_var_with_range( + // &[int32, list_int32, /* and more... */], + // Some("T".into()), None) + // ) + // ``` + // + // However, there is (currently) no way to type a tuple of arbitrary length in `nac3core`. + // + // And this is the best we could do: + // ```rust + // &[ int32, list_int32, tuple_1_int32, tuple_2_int32, tuple_3_int32, ... ], + // ``` + // + // But this is not ideal. + // + // Instead, we delegate the responsibility of typechecking + // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], + // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. + create_fn_by_codegen( self.unifier, &VarMap::new(), prim.name(), self.ndarray_float, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(self.list_int32, "shape")], + &[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], Box::new(move |ctx, obj, fun, args, generator| { let func = match prim { PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 33e44335..82ab00e4 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(239)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(240)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 7b8e1953..d3301d00 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar228]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar228\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 8f3d9bf9..911426b9 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(241)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(246)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 18ad4aa4..d60daf83 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar227, typevar228]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar227\", \"typevar228\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index d58f5f1c..0fea9e25 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 2e14d155..6c24400c 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -814,6 +814,150 @@ impl<'a> Inferencer<'a> { }) } + /// Fold an ndarray `shape` argument. This function aims to fold `shape` arguments like that of + /// (for `np_zeros`). + /// + /// Arguments: + /// * `id` - The name of the function of the function call this `shape` argument is in. Used for error reporting. + /// * `arg_index` - The position (0-based) of this argument in the function call. Used for error reporting. + /// * `shape_expr` - [`Located`] of the input argument. + /// + /// On success, it returns a tuple of + /// 1) the `ndims` value inferred from the input `shape`, + /// 2) and the elaborated expression. Like what other fold functions of [`Inferencer`] would normally return. + fn fold_numpy_function_call_shape_argument( + &mut self, + id: StrRef, + arg_index: usize, + shape_expr: Located, + ) -> Result<(u64, ast::Expr>), HashSet> { + /* + ### Further explanation + + As said, this function aims to fold `shape` arguments, but this is *not* trivial. + The root of the issue is that `nac3core` has to deduce the `ndims` + of the created (for in the case of `np_zeros`) ndarray statically - i.e., during inference time. + + There are three types of valid input to `shape`: + 1. A python `List` (all `int32s`); e.g., `np_zeros([600, 800, 3])` + 2. A python `Tuple` (all `int32s`); e.g., `np_zeros((600, 800, 3))` + 3. An `int32`; e.g., `np_zeros(256)` - this is functionally equivalent to `np_zeros([256])` + + For 2. and 3., `ndims` can be deduce immediately from the inferred type of the input: + - For 2. `ndims` is simply the number of elements found in [`TypeEnum::TTuple`] after typechecking the `shape` argument. + - For 3. `ndims` is simply 1. + + For 1., `ndims` is supposedly the length of the input list. However, the length of the input list + is a runtime property. Therefore (as a hack) we resort to analyzing the argument expression [`ExprKind::List`] + itself to extract the input list length statically. + + This implies that the user could only write: + + ```python + my_rgba_image = np_zeros([600, 800, 4]) + # the shape argument is directly written as a list literal. + # and `nac3core` could therefore tell that ndims is `3` by + # looking at the raw AST expression itself. + ``` + + But not: + + ```python + my_image_dimension = [600, 800, 4] + mystery_function_that_mutates_my_list(my_image_dimension) + my_image = np_zeros(my_image_dimension) + # what is the length now? what is `ndims`? + + # it is *basically impossible* to generally determine the + # length of `my_image_dimension` statically for `ndims`!! + ``` + */ + + // Fold `shape` + let shape = self.fold_expr(shape_expr)?; + let shape_ty = shape.custom.unwrap(); // The inferred type of `shape` + + // Check `shape_ty` to see if its a list of int32s, a tuple of int32s, or just int32. + // Otherwise throw an error as that would mean the user wrote an ill-typed `shape_expr`. + // + // Here, we also take the opportunity to deduce `ndims` statically for 2. and 3. + let shape_ty_enum = &*self.unifier.get_ty(shape_ty); + let ndims = match shape_ty_enum { + TypeEnum::TList { ty } => { + // Handle 1. A list of int32s + + // Typecheck + self.unifier.unify(*ty, self.primitives.int32).map_err(|err| { + HashSet::from([err + .at(Some(shape.location)) + .to_display(self.unifier) + .to_string()]) + })?; + + // Special handling for (1. A python `List` (all `int32s`)). + // Read the doc above this function to see what is going on here. + if let ExprKind::List { elts, .. } = &shape.node { + // The user wrote a List literal as the input argument + elts.len() as u64 + } else { + // This means the user is passing an expression of type `List`, + // but it is done so indirectly (like putting a variable referencing a `List`) + // rather than writing a List literal. We need to report an error. + return Err(HashSet::from([ + format!( + "Expected list literal, tuple, or int32 for argument {arg_num} of {id} at {location}. Input argument is of type list but not a list literal.", + arg_num = arg_index + 1, + location = shape.location + ) + ])); + } + } + TypeEnum::TTuple { ty: tuple_element_types } => { + // Handle 2. A tuple of int32s + + // Typecheck + // The expected type is just the tuple but with all its elements being int32. + let expected_ty = self.unifier.add_ty(TypeEnum::TTuple { + ty: tuple_element_types.iter().map(|_| self.primitives.int32).collect_vec(), + }); + self.unifier.unify(shape_ty, expected_ty).map_err(|err| { + HashSet::from([err + .at(Some(shape.location)) + .to_display(self.unifier) + .to_string()]) + })?; + + // `ndims` can be deduced statically from the inferred Tuple type. + tuple_element_types.len() as u64 + } + TypeEnum::TObj { .. } => { + // Handle 3. An integer (generalized as [`TypeEnum::TObj`]) + + // Typecheck + self.unify(self.primitives.int32, shape_ty, &shape.location)?; + + // Deduce `ndims` + 1 + } + _ => { + // The user wrote an ill-typed `shape_expr`, + // so throw an error. + let shape_ty_str = self.unifier.stringify(shape_ty); + return report_error( + format!( + "Expected list literal, tuple, or int32 for argument {arg_num} of {id}, got {shape_expr_name} of type {shape_ty_str}", + arg_num = arg_index + 1, + shape_expr_name = shape.node.name(), + ) + .as_str(), + shape.location, + ); + } + }; + + Ok((ndims, shape)) + } + /// Tries to fold a special call. Returns [`Some`] if the call expression `func` is a special call, otherwise /// returns [`None`]. fn try_fold_special_call( @@ -1141,25 +1285,15 @@ impl<'a> Inferencer<'a> { })); } - // 1-argument ndarray n-dimensional creation functions + // 1-argument ndarray n-dimensional factory functions if ["np_ndarray".into(), "np_empty".into(), "np_zeros".into(), "np_ones".into()] .contains(id) && args.len() == 1 { - let ExprKind::List { elts, .. } = &args[0].node else { - return report_error( - format!( - "Expected List literal for first argument of {id}, got {}", - args[0].node.name() - ) - .as_str(), - args[0].location, - ); - }; + let shape_expr = args.remove(0); + let (ndims, shape) = + self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling the `shape` - let ndims = elts.len() as u64; - - let arg0 = self.fold_expr(args.remove(0))?; let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ret = make_ndarray_ty( self.unifier, @@ -1170,7 +1304,7 @@ impl<'a> Inferencer<'a> { let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "shape".into(), - ty: arg0.custom.unwrap(), + ty: shape.custom.unwrap(), default_value: None, }], ret, @@ -1186,7 +1320,7 @@ impl<'a> Inferencer<'a> { location: func.location, node: ExprKind::Name { id: *id, ctx: *ctx }, }), - args: vec![arg0], + args: vec![shape], keywords: vec![], }, })); diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index eb124351..1398b0d0 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -71,17 +71,44 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]): def consume_ndarray_1(n: ndarray[float, Literal[1]]): pass +def consume_ndarray_2(n: ndarray[float, Literal[2]]): + pass + def test_ndarray_ctor(): n: ndarray[float, Literal[1]] = np_ndarray([1]) consume_ndarray_1(n) def test_ndarray_empty(): - n: ndarray[float, 1] = np_empty([1]) - consume_ndarray_1(n) + n1: ndarray[float, 1] = np_empty([1]) + consume_ndarray_1(n1) + + n2: ndarray[float, 1] = np_empty(10) + consume_ndarray_1(n2) + + n3: ndarray[float, 1] = np_empty((2,)) + consume_ndarray_1(n3) + + n4: ndarray[float, 2] = np_empty((4, 4)) + consume_ndarray_2(n4) + + dim4 = (5, 2) + n5: ndarray[float, 2] = np_empty(dim4) + consume_ndarray_2(n5) def test_ndarray_zeros(): - n: ndarray[float, 1] = np_zeros([1]) - output_ndarray_float_1(n) + n1: ndarray[float, 1] = np_zeros([1]) + output_ndarray_float_1(n1) + + k = 3 + int32(n1[0]) # to test variable shape inputs + n2: ndarray[float, 1] = np_zeros(k * k) + output_ndarray_float_1(n2) + + n3: ndarray[float, 1] = np_zeros((k * 2,)) + output_ndarray_float_1(n3) + + dim4 = (3, 2 * k) + n4: ndarray[float, 2] = np_zeros(dim4) + output_ndarray_float_2(n4) def test_ndarray_ones(): n: ndarray[float, 1] = np_ones([1])