diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 9ac503a17..632e6d06b 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1,3 +1,4 @@ +use std::cmp::max; use std::collections::{HashMap, HashSet}; use std::convert::{From, TryInto}; use std::iter::once; @@ -1560,17 +1561,18 @@ impl<'a> Inferencer<'a> { keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into())); let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap()); + let arg0_ndims = arraylike_get_ndims(self.unifier, arg0.custom.unwrap()); let ndims = if let Some(ndmin_kw) = ndmin_kw { match &ndmin_kw.node.value.node { ExprKind::Constant { value, .. } => match value { - ast::Constant::Int(value) => *value as u64, + ast::Constant::Int(value) => max(*value as u64, arg0_ndims), _ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])), }, - _ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap()), + _ => arg0_ndims, } } else { - arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) + arg0_ndims }; let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));