diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index e212ac8..fdef020 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,6 +1,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; +use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{Mapping, VarMap}; use nac3parser::ast::{Constant, Location}; @@ -691,3 +692,35 @@ pub fn parse_parameter_default_value( ])) } } + +/// Obtains the element type of an array-like type. +pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { + match &*unifier.get_ty(ty) { + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => + unpack_ndarray_var_tys(unifier, ty).0, + + TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty), + _ => ty + } +} + +/// Obtains the number of dimensions of an array-like type. +pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { + match &*unifier.get_ty(ty) { + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + let ndims = unpack_ndarray_var_tys(unifier, ty).1; + let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else { + panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims)) + }; + + if values.len() > 1 { + todo!("Getting num of dimensions for ndarray with more than one ndim bound is unimplemented") + } + + u64::try_from(values[0].clone()).unwrap() + } + + TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1, + _ => 0 + } +}