core: refactor fold_numpy_function_call_shape_argument

This commit is contained in:
lyken 2024-06-27 10:06:46 +08:00
parent 0a9b7aa16b
commit fb3588b20f
1 changed files with 13 additions and 24 deletions

View File

@ -884,18 +884,6 @@ impl<'a> Inferencer<'a> {
let shape_expr_name = shape_expr.node.name();
let shape_location = shape_expr.location;
// The deduced `ndims` of `shape`. To be determined.
//
// Also the rust compiler is not smart enough to tell `ndims` must be initialized when used
// in this block of code. So we do `Option` + `unwrap`.
let mut ndims: Option<u64> = None;
// Special handling for (1. A python `List` (all `int32s`)).
// Read the doc above this function to see what is going on here.
if let ExprKind::List { elts, .. } = &shape_expr.node {
ndims = Some(elts.len() as u64);
}
// Fold `shape`
let shape = self.fold_expr(shape_expr)?;
let shape_ty = shape.custom.unwrap(); // The inferred type of `shape`
@ -905,7 +893,7 @@ impl<'a> Inferencer<'a> {
//
// Here, we also take the opportunity to deduce `ndims` statically for 2. and 3.
let shape_ty_enum = &*self.unifier.get_ty(shape_ty);
match shape_ty_enum {
let ndims = match shape_ty_enum {
TypeEnum::TList { ty } => {
// Handle 1. A list of int32s
@ -917,13 +905,15 @@ impl<'a> Inferencer<'a> {
.to_string()])
})?;
// Special handling: when nac3core reaches this line of code, two things can happen:
// Case 1: `ndims` is `Some`, this means the previous "Special handling for 1." ran,
// and the user wrote a List literal as the input argument.
// Case 2: `ndims` is `None`, this means the user is passing an expression of type `List`,
// but it is done so indirectly (like putting a variable referencing a `List`)
// rather than writing a List literal. We need to report an error.
if ndims.is_none() {
// Special handling for (1. A python `List` (all `int32s`)).
// Read the doc above this function to see what is going on here.
if let ExprKind::List { elts, .. } = &shape.node {
// The user wrote a List literal as the input argument
elts.len() as u64
} else {
// This means the user is passing an expression of type `List`,
// but it is done so indirectly (like putting a variable referencing a `List`)
// rather than writing a List literal. We need to report an error.
return Err(HashSet::from([
format!(
"Expected List (must be a literal)/Tuple/int32 for argument {arg_num} of {id} at {shape_location}. \
@ -951,7 +941,7 @@ impl<'a> Inferencer<'a> {
})?;
// `ndims` can be deduced statically from the inferred Tuple type.
ndims = Some(tuple_element_types.len() as u64);
tuple_element_types.len() as u64
}
TypeEnum::TObj { .. } => {
// Handle 3. An integer (generalized as [`TypeEnum::TObj`])
@ -960,7 +950,7 @@ impl<'a> Inferencer<'a> {
self.unify(self.primitives.int32, shape_ty, &shape_location)?;
// Deduce `ndims`
ndims = Some(1);
1
}
_ => {
// The user wrote an ill-typed `shape_expr`,
@ -974,9 +964,8 @@ impl<'a> Inferencer<'a> {
shape_location,
);
}
}
};
let ndims = ndims.unwrap_or_else(|| unreachable!("ndims should be initialized"));
Ok((ndims, shape))
}