Compare commits
2 Commits
92e626a5c4
...
fb3588b20f
Author | SHA1 | Date |
---|---|---|
lyken | fb3588b20f | |
lyken | 0a9b7aa16b |
|
@ -831,69 +831,59 @@ impl<'a> Inferencer<'a> {
|
|||
/// On success, it returns a tuple of
|
||||
/// 1) the `ndims` value inferred from the input `shape`,
|
||||
/// 2) and the elaborated expression. Like what other fold functions of [`Inferencer`] would normally return.
|
||||
///
|
||||
/// ### Further explanation
|
||||
///
|
||||
/// As said, this function aims to fold `shape` arguments, but this is *not* trivial.
|
||||
/// The root of the issue is that `nac3core` has to deduce the `ndims`
|
||||
/// of the created (for in the case of `np_zeros`) ndarray statically - i.e., during inference time.
|
||||
///
|
||||
/// There are three types of valid input to `shape`:
|
||||
/// 1. A python `List` (all `int32s`); e.g., `np_zeros([600, 800, 3])`
|
||||
/// 2. A python `Tuple` (all `int32s`); e.g., `np_zeros((600, 800, 3))`
|
||||
/// 3. An `int32`; e.g., `np_zeros(256)` - this is functionally equivalent to `np_zeros([256])`
|
||||
///
|
||||
/// For 2. and 3., `ndims` can be deduce immediately from the inferred type of the input:
|
||||
/// - For 2. `ndims` is simply the number of elements found in [`TypeEnum::TTuple`] after typechecking the `shape` argument.
|
||||
/// - For 3. `ndims` is simply 1.
|
||||
///
|
||||
/// For 1., `ndims` is supposedly the length of the input list. However, the length of the input list
|
||||
/// is a runtime property. Therefore (as a hack) we resort to analyzing the argument expression [`ExprKind::List`]
|
||||
/// itself to extract the input list length statically.
|
||||
///
|
||||
/// This implies that the user could only write:
|
||||
///
|
||||
/// ```python
|
||||
/// my_rgba_image = np_zeros([600, 800, 4])
|
||||
/// # the shape argument is directly written as a list literal.
|
||||
/// # and `nac3core` could therefore tell that ndims is `3` by
|
||||
/// # looking at the raw AST expression itself.
|
||||
/// ```
|
||||
///
|
||||
/// But not:
|
||||
///
|
||||
/// ```python
|
||||
/// my_image_dimension = [600, 800, 4]
|
||||
/// mystery_function_that_mutates_my_list(my_image_dimension)
|
||||
/// my_image = np_zeros(my_image_dimension)
|
||||
/// # what is the length now? what is `ndims`?
|
||||
///
|
||||
/// # it is *basically impossible* to generally determine the
|
||||
/// # length of `my_image_dimension` statically for `ndims`!!
|
||||
/// ```
|
||||
fn fold_numpy_function_call_shape_argument(
|
||||
&mut self,
|
||||
id: StrRef,
|
||||
arg_index: usize,
|
||||
shape_expr: Located<ExprKind>,
|
||||
) -> Result<(u64, ast::Expr<Option<Type>>), HashSet<String>> {
|
||||
/*
|
||||
### Further explanation
|
||||
|
||||
As said, this function aims to fold `shape` arguments, but this is *not* trivial.
|
||||
The root of the issue is that `nac3core` has to deduce the `ndims`
|
||||
of the created (for in the case of `np_zeros`) ndarray statically - i.e., during inference time.
|
||||
|
||||
There are three types of valid input to `shape`:
|
||||
1. A python `List` (all `int32s`); e.g., `np_zeros([600, 800, 3])`
|
||||
2. A python `Tuple` (all `int32s`); e.g., `np_zeros((600, 800, 3))`
|
||||
3. An `int32`; e.g., `np_zeros(256)` - this is functionally equivalent to `np_zeros([256])`
|
||||
|
||||
For 2. and 3., `ndims` can be deduce immediately from the inferred type of the input:
|
||||
- For 2. `ndims` is simply the number of elements found in [`TypeEnum::TTuple`] after typechecking the `shape` argument.
|
||||
- For 3. `ndims` is simply 1.
|
||||
|
||||
For 1., `ndims` is supposedly the length of the input list. However, the length of the input list
|
||||
is a runtime property. Therefore (as a hack) we resort to analyzing the argument expression [`ExprKind::List`]
|
||||
itself to extract the input list length statically.
|
||||
|
||||
This implies that the user could only write:
|
||||
|
||||
```python
|
||||
my_rgba_image = np_zeros([600, 800, 4])
|
||||
# the shape argument is directly written as a list literal.
|
||||
# and `nac3core` could therefore tell that ndims is `3` by
|
||||
# looking at the raw AST expression itself.
|
||||
```
|
||||
|
||||
But not:
|
||||
|
||||
```python
|
||||
my_image_dimension = [600, 800, 4]
|
||||
mystery_function_that_mutates_my_list(my_image_dimension)
|
||||
my_image = np_zeros(my_image_dimension)
|
||||
# what is the length now? what is `ndims`?
|
||||
|
||||
# it is *basically impossible* to generally determine the
|
||||
# length of `my_image_dimension` statically for `ndims`!!
|
||||
```
|
||||
*/
|
||||
|
||||
// Auxillary details for error reporting.
|
||||
// Predefined here because `shape_expr` will be moved when doing `fold_expr`
|
||||
let shape_expr_name = shape_expr.node.name();
|
||||
let shape_location = shape_expr.location;
|
||||
|
||||
// The deduced `ndims` of `shape`. To be determined.
|
||||
//
|
||||
// Also the rust compiler is not smart enough to tell `ndims` must be initialized when used
|
||||
// in this block of code. So we do `Option` + `unwrap`.
|
||||
let mut ndims: Option<u64> = None;
|
||||
|
||||
// Special handling for (1. A python `List` (all `int32s`)).
|
||||
// Read the doc above this function to see what is going on here.
|
||||
if let ExprKind::List { elts, .. } = &shape_expr.node {
|
||||
ndims = Some(elts.len() as u64);
|
||||
}
|
||||
|
||||
// Fold `shape`
|
||||
let shape = self.fold_expr(shape_expr)?;
|
||||
let shape_ty = shape.custom.unwrap(); // The inferred type of `shape`
|
||||
|
@ -903,7 +893,7 @@ impl<'a> Inferencer<'a> {
|
|||
//
|
||||
// Here, we also take the opportunity to deduce `ndims` statically for 2. and 3.
|
||||
let shape_ty_enum = &*self.unifier.get_ty(shape_ty);
|
||||
match shape_ty_enum {
|
||||
let ndims = match shape_ty_enum {
|
||||
TypeEnum::TList { ty } => {
|
||||
// Handle 1. A list of int32s
|
||||
|
||||
|
@ -915,13 +905,15 @@ impl<'a> Inferencer<'a> {
|
|||
.to_string()])
|
||||
})?;
|
||||
|
||||
// Special handling: when nac3core reaches this line of code, two things can happen:
|
||||
// Case 1: `ndims` is `Some`, this means the previous "Special handling for 1." ran,
|
||||
// and the user wrote a List literal as the input argument.
|
||||
// Case 2: `ndims` is `None`, this means the user is passing an expression of type `List`,
|
||||
// but it is done so indirectly (like putting a variable referencing a `List`)
|
||||
// rather than writing a List literal. We need to report an error.
|
||||
if ndims.is_none() {
|
||||
// Special handling for (1. A python `List` (all `int32s`)).
|
||||
// Read the doc above this function to see what is going on here.
|
||||
if let ExprKind::List { elts, .. } = &shape.node {
|
||||
// The user wrote a List literal as the input argument
|
||||
elts.len() as u64
|
||||
} else {
|
||||
// This means the user is passing an expression of type `List`,
|
||||
// but it is done so indirectly (like putting a variable referencing a `List`)
|
||||
// rather than writing a List literal. We need to report an error.
|
||||
return Err(HashSet::from([
|
||||
format!(
|
||||
"Expected List (must be a literal)/Tuple/int32 for argument {arg_num} of {id} at {shape_location}. \
|
||||
|
@ -949,7 +941,7 @@ impl<'a> Inferencer<'a> {
|
|||
})?;
|
||||
|
||||
// `ndims` can be deduced statically from the inferred Tuple type.
|
||||
ndims = Some(tuple_element_types.len() as u64);
|
||||
tuple_element_types.len() as u64
|
||||
}
|
||||
TypeEnum::TObj { .. } => {
|
||||
// Handle 3. An integer (generalized as [`TypeEnum::TObj`])
|
||||
|
@ -958,7 +950,7 @@ impl<'a> Inferencer<'a> {
|
|||
self.unify(self.primitives.int32, shape_ty, &shape_location)?;
|
||||
|
||||
// Deduce `ndims`
|
||||
ndims = Some(1);
|
||||
1
|
||||
}
|
||||
_ => {
|
||||
// The user wrote an ill-typed `shape_expr`,
|
||||
|
@ -972,9 +964,8 @@ impl<'a> Inferencer<'a> {
|
|||
shape_location,
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let ndims = ndims.unwrap_or_else(|| unreachable!("ndims should be initialized"));
|
||||
Ok((ndims, shape))
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue