From 5640a793e2fcf98151848cfee9fa831e9d11e06c Mon Sep 17 00:00:00 2001 From: lyken Date: Fri, 30 Aug 2024 17:11:38 +0800 Subject: [PATCH 1/3] core: allow np_full to take tuple shapes --- nac3core/src/typecheck/type_inferencer/mod.rs | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a5b8cd49..d27f746e 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1550,36 +1550,29 @@ impl<'a> Inferencer<'a> { } // 2-argument ndarray n-dimensional creation functions if id == &"np_full".into() && args.len() == 2 { - let ExprKind::List { elts, .. } = &args[0].node else { - return report_error( - format!( - "Expected List literal for first argument of {id}, got {}", - args[0].node.name() - ) - .as_str(), - args[0].location, - ); - }; + // Parse arguments + let shape_expr = args.remove(0); + let (ndims, shape) = + self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape` - let ndims = elts.len() as u64; + let fill_value = self.fold_expr(args.remove(0))?; - let arg0 = self.fold_expr(args.remove(0))?; - let arg1 = self.fold_expr(args.remove(0))?; - - let ty = arg1.custom.unwrap(); + // Build the return type + let dtype = fill_value.custom.unwrap(); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); - let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); + let ret = make_ndarray_ty(self.unifier, self.primitives, Some(dtype), Some(ndims)); + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { name: "shape".into(), - ty: arg0.custom.unwrap(), + ty: shape.custom.unwrap(), default_value: None, is_vararg: false, }, FuncArg { name: "fill_value".into(), - ty: arg1.custom.unwrap(), + ty: fill_value.custom.unwrap(), default_value: None, is_vararg: false, }, @@ -1597,7 +1590,7 @@ impl<'a> Inferencer<'a> { location: func.location, node: ExprKind::Name { id: *id, ctx: *ctx }, }), - args: vec![arg0, arg1], + args: vec![shape, fill_value], keywords: vec![], }, })); -- 2.44.2 From 7f629f157970acf6d99a41ebba6ed733e157f5b7 Mon Sep 17 00:00:00 2001 From: lyken Date: Fri, 30 Aug 2024 17:12:44 +0800 Subject: [PATCH 2/3] core: fix comment in unify_call --- nac3core/src/typecheck/typedef/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 93ccd9fb..22d04791 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -670,8 +670,8 @@ impl Unifier { let num_args = posargs.len() + kwargs.len(); // Now we check the arguments against the parameters, - // and depending on what `call_info` is, we might change how the behavior `unify_call()` - // in hopes to improve user error messages when type checking fails. + // and depending on what `call_info` is, we might change how `unify_call()` behaves + // to improve user error messages when type checking fails. match operator_info { Some(OperatorInfo::IsBinaryOp { self_type, operator }) => { // The call is written in the form of (say) `a + b`. -- 2.44.2 From 3e92c491f55fccce8d5cfaeba5480b4b9b538b09 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 11 Sep 2024 15:52:43 +0800 Subject: [PATCH 3/3] [standalone] Add tests creating ndarrays with tuple dims --- nac3standalone/demo/src/ndarray.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 9664b3f0..577ad9c3 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -114,12 +114,22 @@ def test_ndarray_ones(): n: ndarray[float, 1] = np_ones([1]) output_ndarray_float_1(n) + dim = (1,) + n_tup: ndarray[float, 1] = np_ones(dim) + output_ndarray_float_1(n_tup) + def test_ndarray_full(): n_float: ndarray[float, 1] = np_full([1], 2.0) output_ndarray_float_1(n_float) n_i32: ndarray[int32, 1] = np_full([1], 2) output_ndarray_int32_1(n_i32) + dim = (1,) + n_float_tup: ndarray[float, 1] = np_full(dim, 2.0) + output_ndarray_float_1(n_float_tup) + n_i32_tup: ndarray[int32, 1] = np_full(dim, 2) + output_ndarray_int32_1(n_i32_tup) + def test_ndarray_eye(): n: ndarray[float, 2] = np_eye(2) output_ndarray_float_2(n) -- 2.44.2