core: refactor fold_numpy_function_call_shape_argument
This commit is contained in:
parent
0a9b7aa16b
commit
fb3588b20f
|
@ -884,18 +884,6 @@ impl<'a> Inferencer<'a> {
|
|||
let shape_expr_name = shape_expr.node.name();
|
||||
let shape_location = shape_expr.location;
|
||||
|
||||
// The deduced `ndims` of `shape`. To be determined.
|
||||
//
|
||||
// Also the rust compiler is not smart enough to tell `ndims` must be initialized when used
|
||||
// in this block of code. So we do `Option` + `unwrap`.
|
||||
let mut ndims: Option<u64> = None;
|
||||
|
||||
// Special handling for (1. A python `List` (all `int32s`)).
|
||||
// Read the doc above this function to see what is going on here.
|
||||
if let ExprKind::List { elts, .. } = &shape_expr.node {
|
||||
ndims = Some(elts.len() as u64);
|
||||
}
|
||||
|
||||
// Fold `shape`
|
||||
let shape = self.fold_expr(shape_expr)?;
|
||||
let shape_ty = shape.custom.unwrap(); // The inferred type of `shape`
|
||||
|
@ -905,7 +893,7 @@ impl<'a> Inferencer<'a> {
|
|||
//
|
||||
// Here, we also take the opportunity to deduce `ndims` statically for 2. and 3.
|
||||
let shape_ty_enum = &*self.unifier.get_ty(shape_ty);
|
||||
match shape_ty_enum {
|
||||
let ndims = match shape_ty_enum {
|
||||
TypeEnum::TList { ty } => {
|
||||
// Handle 1. A list of int32s
|
||||
|
||||
|
@ -917,13 +905,15 @@ impl<'a> Inferencer<'a> {
|
|||
.to_string()])
|
||||
})?;
|
||||
|
||||
// Special handling: when nac3core reaches this line of code, two things can happen:
|
||||
// Case 1: `ndims` is `Some`, this means the previous "Special handling for 1." ran,
|
||||
// and the user wrote a List literal as the input argument.
|
||||
// Case 2: `ndims` is `None`, this means the user is passing an expression of type `List`,
|
||||
// Special handling for (1. A python `List` (all `int32s`)).
|
||||
// Read the doc above this function to see what is going on here.
|
||||
if let ExprKind::List { elts, .. } = &shape.node {
|
||||
// The user wrote a List literal as the input argument
|
||||
elts.len() as u64
|
||||
} else {
|
||||
// This means the user is passing an expression of type `List`,
|
||||
// but it is done so indirectly (like putting a variable referencing a `List`)
|
||||
// rather than writing a List literal. We need to report an error.
|
||||
if ndims.is_none() {
|
||||
return Err(HashSet::from([
|
||||
format!(
|
||||
"Expected List (must be a literal)/Tuple/int32 for argument {arg_num} of {id} at {shape_location}. \
|
||||
|
@ -951,7 +941,7 @@ impl<'a> Inferencer<'a> {
|
|||
})?;
|
||||
|
||||
// `ndims` can be deduced statically from the inferred Tuple type.
|
||||
ndims = Some(tuple_element_types.len() as u64);
|
||||
tuple_element_types.len() as u64
|
||||
}
|
||||
TypeEnum::TObj { .. } => {
|
||||
// Handle 3. An integer (generalized as [`TypeEnum::TObj`])
|
||||
|
@ -960,7 +950,7 @@ impl<'a> Inferencer<'a> {
|
|||
self.unify(self.primitives.int32, shape_ty, &shape_location)?;
|
||||
|
||||
// Deduce `ndims`
|
||||
ndims = Some(1);
|
||||
1
|
||||
}
|
||||
_ => {
|
||||
// The user wrote an ill-typed `shape_expr`,
|
||||
|
@ -974,9 +964,8 @@ impl<'a> Inferencer<'a> {
|
|||
shape_location,
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let ndims = ndims.unwrap_or_else(|| unreachable!("ndims should be initialized"));
|
||||
Ok((ndims, shape))
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue