Compare commits

...

2 Commits

Author SHA1 Message Date
lyken fb3588b20f core: refactor fold_numpy_function_call_shape_argument 2024-06-27 10:06:46 +08:00
lyken 0a9b7aa16b core: move comment 2024-06-27 09:50:47 +08:00
1 changed files with 55 additions and 64 deletions

View File

@ -831,69 +831,59 @@ impl<'a> Inferencer<'a> {
/// On success, it returns a tuple of
/// 1) the `ndims` value inferred from the input `shape`,
/// 2) and the elaborated expression. Like what other fold functions of [`Inferencer`] would normally return.
///
/// ### Further explanation
///
/// As said, this function aims to fold `shape` arguments, but this is *not* trivial.
/// The root of the issue is that `nac3core` has to deduce the `ndims`
/// of the created (for in the case of `np_zeros`) ndarray statically - i.e., during inference time.
///
/// There are three types of valid input to `shape`:
/// 1. A python `List` (all `int32s`); e.g., `np_zeros([600, 800, 3])`
/// 2. A python `Tuple` (all `int32s`); e.g., `np_zeros((600, 800, 3))`
/// 3. An `int32`; e.g., `np_zeros(256)` - this is functionally equivalent to `np_zeros([256])`
///
/// For 2. and 3., `ndims` can be deduce immediately from the inferred type of the input:
/// - For 2. `ndims` is simply the number of elements found in [`TypeEnum::TTuple`] after typechecking the `shape` argument.
/// - For 3. `ndims` is simply 1.
///
/// For 1., `ndims` is supposedly the length of the input list. However, the length of the input list
/// is a runtime property. Therefore (as a hack) we resort to analyzing the argument expression [`ExprKind::List`]
/// itself to extract the input list length statically.
///
/// This implies that the user could only write:
///
/// ```python
/// my_rgba_image = np_zeros([600, 800, 4])
/// # the shape argument is directly written as a list literal.
/// # and `nac3core` could therefore tell that ndims is `3` by
/// # looking at the raw AST expression itself.
/// ```
///
/// But not:
///
/// ```python
/// my_image_dimension = [600, 800, 4]
/// mystery_function_that_mutates_my_list(my_image_dimension)
/// my_image = np_zeros(my_image_dimension)
/// # what is the length now? what is `ndims`?
///
/// # it is *basically impossible* to generally determine the
/// # length of `my_image_dimension` statically for `ndims`!!
/// ```
fn fold_numpy_function_call_shape_argument(
&mut self,
id: StrRef,
arg_index: usize,
shape_expr: Located<ExprKind>,
) -> Result<(u64, ast::Expr<Option<Type>>), HashSet<String>> {
/*
### Further explanation
As said, this function aims to fold `shape` arguments, but this is *not* trivial.
The root of the issue is that `nac3core` has to deduce the `ndims`
of the created (for in the case of `np_zeros`) ndarray statically - i.e., during inference time.
There are three types of valid input to `shape`:
1. A python `List` (all `int32s`); e.g., `np_zeros([600, 800, 3])`
2. A python `Tuple` (all `int32s`); e.g., `np_zeros((600, 800, 3))`
3. An `int32`; e.g., `np_zeros(256)` - this is functionally equivalent to `np_zeros([256])`
For 2. and 3., `ndims` can be deduce immediately from the inferred type of the input:
- For 2. `ndims` is simply the number of elements found in [`TypeEnum::TTuple`] after typechecking the `shape` argument.
- For 3. `ndims` is simply 1.
For 1., `ndims` is supposedly the length of the input list. However, the length of the input list
is a runtime property. Therefore (as a hack) we resort to analyzing the argument expression [`ExprKind::List`]
itself to extract the input list length statically.
This implies that the user could only write:
```python
my_rgba_image = np_zeros([600, 800, 4])
# the shape argument is directly written as a list literal.
# and `nac3core` could therefore tell that ndims is `3` by
# looking at the raw AST expression itself.
```
But not:
```python
my_image_dimension = [600, 800, 4]
mystery_function_that_mutates_my_list(my_image_dimension)
my_image = np_zeros(my_image_dimension)
# what is the length now? what is `ndims`?
# it is *basically impossible* to generally determine the
# length of `my_image_dimension` statically for `ndims`!!
```
*/
// Auxillary details for error reporting.
// Predefined here because `shape_expr` will be moved when doing `fold_expr`
let shape_expr_name = shape_expr.node.name();
let shape_location = shape_expr.location;
// The deduced `ndims` of `shape`. To be determined.
//
// Also the rust compiler is not smart enough to tell `ndims` must be initialized when used
// in this block of code. So we do `Option` + `unwrap`.
let mut ndims: Option<u64> = None;
// Special handling for (1. A python `List` (all `int32s`)).
// Read the doc above this function to see what is going on here.
if let ExprKind::List { elts, .. } = &shape_expr.node {
ndims = Some(elts.len() as u64);
}
// Fold `shape`
let shape = self.fold_expr(shape_expr)?;
let shape_ty = shape.custom.unwrap(); // The inferred type of `shape`
@ -903,7 +893,7 @@ impl<'a> Inferencer<'a> {
//
// Here, we also take the opportunity to deduce `ndims` statically for 2. and 3.
let shape_ty_enum = &*self.unifier.get_ty(shape_ty);
match shape_ty_enum {
let ndims = match shape_ty_enum {
TypeEnum::TList { ty } => {
// Handle 1. A list of int32s
@ -915,13 +905,15 @@ impl<'a> Inferencer<'a> {
.to_string()])
})?;
// Special handling: when nac3core reaches this line of code, two things can happen:
// Case 1: `ndims` is `Some`, this means the previous "Special handling for 1." ran,
// and the user wrote a List literal as the input argument.
// Case 2: `ndims` is `None`, this means the user is passing an expression of type `List`,
// but it is done so indirectly (like putting a variable referencing a `List`)
// rather than writing a List literal. We need to report an error.
if ndims.is_none() {
// Special handling for (1. A python `List` (all `int32s`)).
// Read the doc above this function to see what is going on here.
if let ExprKind::List { elts, .. } = &shape.node {
// The user wrote a List literal as the input argument
elts.len() as u64
} else {
// This means the user is passing an expression of type `List`,
// but it is done so indirectly (like putting a variable referencing a `List`)
// rather than writing a List literal. We need to report an error.
return Err(HashSet::from([
format!(
"Expected List (must be a literal)/Tuple/int32 for argument {arg_num} of {id} at {shape_location}. \
@ -949,7 +941,7 @@ impl<'a> Inferencer<'a> {
})?;
// `ndims` can be deduced statically from the inferred Tuple type.
ndims = Some(tuple_element_types.len() as u64);
tuple_element_types.len() as u64
}
TypeEnum::TObj { .. } => {
// Handle 3. An integer (generalized as [`TypeEnum::TObj`])
@ -958,7 +950,7 @@ impl<'a> Inferencer<'a> {
self.unify(self.primitives.int32, shape_ty, &shape_location)?;
// Deduce `ndims`
ndims = Some(1);
1
}
_ => {
// The user wrote an ill-typed `shape_expr`,
@ -972,9 +964,8 @@ impl<'a> Inferencer<'a> {
shape_location,
);
}
}
};
let ndims = ndims.unwrap_or_else(|| unreachable!("ndims should be initialized"));
Ok((ndims, shape))
}