core: support tuple and int32 input for np_empty, np_ones, and more #434

Merged
sb10q merged 1 commits from ndfactory-tuple into master 2024-08-17 17:37:21 +08:00
Collaborator

Fixes #427 + some additions.

Originally, `nac3core` only supported list literals as the `shape` parameter for numpy factory functions (e.g., `np_zeros([600, 800, 3])`). These functions include `np_ndarray`, `np_zeros`, `np_ones`, and `np_full`.

This PR extends the `shape` parameter to allow:

  1. `shape` given as an `int32` scalar (e.g., `np_zeros(256)`)
  2. `shape` given as a tuple of `int32` of arbitrary length (e.g., `np_zeros((600, 800, 3))`, `np_zeros((4096,))`)

Neither form has to be written as a literal to be a valid `shape` input. This differs from list inputs, which must be written as literals because `nac3core` has to statically deduce the `ndims` typevar of the returned ndarray. That is:

### Current nac3core's support for the `shape` parameter
np_zeros([600, 800, 3]) # ok

my_dim = [600, 800, 3]
np_zeros(my_dim) # BAD
# `my_dim`'s length is hard to know statically in general
# to deduce the `ndims` typevar of the returned ndarray.

### This PR now adds support for:
my_number = 64
np_zeros(4096) # ok, this is functionally equivalent to `np_zeros([4096])`
np_zeros(my_number) # ok
np_zeros((1024,)) # ok
np_zeros((1080, 1920, 4)) # ok
np_zeros((1080, 1920, my_number)) # ok

my_dim = (600, 800, my_number * my_number)
np_zeros(my_dim) # ok
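The reason tuples and `int32` values escape the literal-only restriction can be sketched with a toy type model. This is only an illustration under assumptions: `Ty` and `ndims_from_type` are hypothetical names, not nac3core's actual `TypeEnum` or inferencer API. The point is that a tuple type lists one type per element, so its length is part of the type, while a list type records only the element type.

```rust
/// Toy stand-in for an inferred type (hypothetical; not nac3core's `TypeEnum`).
#[derive(Debug)]
enum Ty {
    Int32,
    /// A list type stores only its element type; the length is a runtime property.
    List(Box<Ty>),
    /// A tuple type stores one type per element, so the length is known statically.
    Tuple(Vec<Ty>),
}

/// Try to deduce `ndims` from the type of `shape` alone.
/// Returns `None` when the type does not carry enough information.
fn ndims_from_type(ty: &Ty) -> Option<u64> {
    match ty {
        // A scalar `int32` shape always describes a 1-dimensional array.
        Ty::Int32 => Some(1),
        // The number of element types is the number of dimensions.
        Ty::Tuple(elems) => Some(elems.len() as u64),
        // The type alone cannot say how long the list is, which is why
        // the list case has to fall back to inspecting a list literal.
        Ty::List(_) => None,
    }
}
```

With this model, a `(600, 800, my_number * my_number)` tuple yields `Some(3)` regardless of how its elements are computed, while any list-typed value yields `None`.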
lyken requested review from derppening 2024-06-25 15:53:13 +08:00
derppening requested changes 2024-06-26 18:56:28 +08:00
Dismissed
@ -823,0 +832,4 @@
/// 1) the `ndims` value inferred from the input `shape`,
/// 2) and the elaborated expression. Like what other fold functions of [`Inferencer`] would normally return.
///
/// ### Further explanation
Collaborator

I am quite torn on whether "Further Explanation" should be part of the API documentation. I am of the opinion that it belongs in comments within the method body instead, as it is not something that callers of the API need to be concerned about.

Author
Collaborator

Okay, I will move it into the implementation.

derppening marked this conversation as resolved
@ -823,0 +886,4 @@
//
// Also the rust compiler is not smart enough to tell `ndims` must be initialized when used
// in this block of code. So we do `Option` + `unwrap`.
let mut ndims: Option<u64> = None;
Collaborator

I think this can be rewritten this way:

===================================================================
diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs
--- a/nac3core/src/typecheck/type_inferencer/mod.rs	(revision 92e626a5c41339240a89f50f4e39cbb325d1f5fa)
+++ b/nac3core/src/typecheck/type_inferencer/mod.rs	(date 1719399227416)
@@ -882,18 +882,6 @@
         let shape_expr_name = shape_expr.node.name();
         let shape_location = shape_expr.location;
 
-        // The deduced `ndims` of `shape`. To be determined.
-        //
-        // Also the rust compiler is not smart enough to tell `ndims` must be initialized when used
-        // in this block of code. So we do `Option` + `unwrap`.
-        let mut ndims: Option<u64> = None;
-
-        // Special handling for (1. A python `List` (all `int32s`)).
-        // Read the doc above this function to see what is going on here.
-        if let ExprKind::List { elts, .. } = &shape_expr.node {
-            ndims = Some(elts.len() as u64);
-        }
-
         // Fold `shape`
         let shape = self.fold_expr(shape_expr)?;
         let shape_ty = shape.custom.unwrap(); // The inferred type of `shape`
@@ -903,7 +891,7 @@
         //
         // Here, we also take the opportunity to deduce `ndims` statically for 2. and 3.
         let shape_ty_enum = &*self.unifier.get_ty(shape_ty);
-        match shape_ty_enum {
+        let ndims = match shape_ty_enum {
             TypeEnum::TList { ty } => {
                 // Handle 1. A list of int32s
 
@@ -915,13 +903,11 @@
                         .to_string()])
                 })?;
 
-                // Special handling: when nac3core reaches this line of code, two things can happen:
-                //   Case 1: `ndims` is `Some`, this means the previous "Special handling for 1." ran,
-                //           and the user wrote a List literal as the input argument.
-                //   Case 2: `ndims` is `None`, this means the user is passing an expression of type `List`,
-                //           but it is done so indirectly (like putting a variable referencing a `List`)
-                //           rather than writing a List literal. We need to report an error.
-                if ndims.is_none() {
+                // Special handling for (1. A python `List` (all `int32s`)).
+                // Read the doc above this function to see what is going on here.
+                if let ExprKind::List { elts, .. } = &shape.node {
+                    elts.len() as u64
+                } else {
                     return Err(HashSet::from([
                         format!(
                             "Expected List (must be a literal)/Tuple/int32 for argument {arg_num} of {id} at {shape_location}. \
@@ -949,7 +935,7 @@
                 })?;
 
                 // `ndims` can be deduced statically from the inferred Tuple type.
-                ndims = Some(tuple_element_types.len() as u64);
+                tuple_element_types.len() as u64
             }
             TypeEnum::TObj { .. } => {
                 // Handle 3. An integer (generalized as [`TypeEnum::TObj`])
@@ -958,7 +944,7 @@
                 self.unify(self.primitives.int32, shape_ty, &shape_location)?;
 
                 // Deduce `ndims`
-                ndims = Some(1);
+                1
             }
             _ => {
                 // The user wrote an ill-typed `shape_expr`,
@@ -972,9 +958,8 @@
                     shape_location,
                 );
             }
-        }
+        };
 
-        let ndims = ndims.unwrap_or_else(|| unreachable!("ndims should be initialized"));
         Ok((ndims, shape))
     }
 

That way ndims would be initialized once.
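The suggested refactor relies on `match` being an expression in Rust. The following is a reduced sketch of the two styles, not the actual nac3core code: the `Shape` enum and both function names are hypothetical stand-ins for the inferencer's real types.

```rust
/// Hypothetical, simplified classification of the `shape` argument.
enum Shape {
    ListLiteral(usize),
    ListValue, // a list-typed expression that is not a literal
    Tuple(usize),
    Int32,
}

/// Before: a mutable `Option` that every arm fills in, unwrapped at the end.
fn ndims_option(shape: &Shape) -> Result<u64, String> {
    let mut ndims: Option<u64> = None;
    match shape {
        Shape::ListLiteral(n) | Shape::Tuple(n) => ndims = Some(*n as u64),
        Shape::Int32 => ndims = Some(1),
        Shape::ListValue => return Err("list must be a literal".to_string()),
    }
    Ok(ndims.unwrap_or_else(|| unreachable!("ndims should be initialized")))
}

/// After: the `match` itself produces the value, so `ndims` is bound exactly once
/// and the compiler checks that every non-returning arm yields a `u64`.
fn ndims_match(shape: &Shape) -> Result<u64, String> {
    let ndims = match shape {
        Shape::ListLiteral(n) | Shape::Tuple(n) => *n as u64,
        Shape::Int32 => 1,
        // An early `return` has type `!`, which fits any arm type.
        Shape::ListValue => return Err("list must be a literal".to_string()),
    };
    Ok(ndims)
}
```

Both functions behave the same; the second removes the `Option`, the `unwrap`, and the comment explaining why they were needed.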

Author
Collaborator

Ah I didn't think to do that! Thanks.

derppening marked this conversation as resolved
@ -823,0 +924,4 @@
if ndims.is_none() {
return Err(HashSet::from([
format!(
"Expected List (must be a literal)/Tuple/int32 for argument {arg_num} of {id} at {shape_location}. \
Collaborator

Call it a "List Literal".

derppening marked this conversation as resolved
@ -823,0 +927,4 @@
"Expected List (must be a literal)/Tuple/int32 for argument {arg_num} of {id} at {shape_location}. \
There, you are passing a value of type List as the argument. \
However, this argument is special - you must only supply this argument with a List literal. \
On the other hand, you may instead pass in a tuple, and there would be no such restriction.",
Collaborator

Is this long explanation necessary?

Author
Collaborator

I wrote it this long because I think this is quite a notable quirk of NAC3, and these factory functions are probably going to be used frequently by many people. Nonetheless, I would like your opinion on how I could improve this.

Collaborator
Expected list literal, tuple, or int32 for argument {arg_num} of {id} at {shape_location}.

Note: shape must be a compile-time constant.
Author
Collaborator

I am not entirely sure what "compile-time constant" means exactly. I guess it means a "static"-ishly defined value. In my implementation, `shape` does not need to be a compile-time constant for `tuple` and `int32` in order to determine `ndims`, unless I am misunderstanding something horribly...

Author
Collaborator

Also for the record, here is what nac3core (before and also now) allows in my understanding:

np_zeros([10, 20, 30]) # ok

MY_SHAPE = [10, 20, 30]
np_zeros(MY_SHAPE) # not ok, even though one might say `MY_SHAPE` is a compile-time constant.

a = 10
b = some_weird_function(a)
np_zeros([a, b]) # this is ok, the `shape` argument just has to be a list literal.
Author
Collaborator

For now, I would replace the message with:

Expected list literal, tuple, or int32 for argument {arg_num} of {id} at {location}. Input argument is of type list but not written as a list literal.
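As a rough sketch of how the revised message could be produced, the snippet below follows the `Err(HashSet::from([format!(...)]))` shape visible in the diff above. The function name `list_literal_error` and its parameter types are assumptions made for illustration; the real code builds the message inline inside the inferencer.

```rust
use std::collections::HashSet;

/// Hypothetical helper: builds the error set for a list-typed `shape`
/// argument that was not written as a list literal.
fn list_literal_error(arg_num: usize, id: &str, location: &str) -> HashSet<String> {
    // Inline captured identifiers in `format!` require Rust 1.58 or newer.
    HashSet::from([format!(
        "Expected list literal, tuple, or int32 for argument {arg_num} of {id} at {location}. \
         Input argument is of type list but not written as a list literal."
    )])
}
```

The trailing backslash inside the string literal joins the two source lines and skips the leading whitespace of the second, so the message stays a single line.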
derppening marked this conversation as resolved
Author
Collaborator

Revised.

lyken force-pushed ndfactory-tuple from fb3588b20f to ee5389f91b 2024-06-27 10:09:14 +08:00 Compare
derppening reviewed 2024-06-27 13:47:51 +08:00
@ -824,0 +883,4 @@
// Auxillary details for error reporting.
// Predefined here because `shape_expr` will be moved when doing `fold_expr`
let shape_expr_name = shape_expr.node.name();
let shape_location = shape_expr.location;
Collaborator

On second thought, these shouldn't be necessary, as `shape` will contain this information AFAIK.

derppening marked this conversation as resolved
lyken force-pushed ndfactory-tuple from 621ed6382c to 5b11a1dbdd 2024-06-27 14:31:06 +08:00 Compare
Author
Collaborator

Further revision and squashed history.

lyken added 1 commit 2024-06-27 14:47:09 +08:00
derppening approved these changes 2024-06-27 14:53:41 +08:00
sb10q was assigned by derppening 2024-06-27 14:53:46 +08:00
sb10q merged commit 9808923258 into master 2024-06-27 14:54:55 +08:00
sb10q deleted branch ndfactory-tuple 2024-06-27 14:54:56 +08:00
Reference: M-Labs/nac3#434