1
0
forked from M-Labs/nac3

core/typecheck: fix np_array ndmin bug

This commit is contained in:
lyken 2024-08-13 12:50:04 +08:00
parent 7e3d87f841
commit 35a7cecc12
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8

View File

@ -1,3 +1,4 @@
use std::cmp::max;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::convert::{From, TryInto}; use std::convert::{From, TryInto};
use std::iter::once; use std::iter::once;
@ -1560,17 +1561,18 @@ impl<'a> Inferencer<'a> {
keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into())); keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into()));
let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap()); let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap());
let arg0_ndims = arraylike_get_ndims(self.unifier, arg0.custom.unwrap());
let ndims = if let Some(ndmin_kw) = ndmin_kw { let ndims = if let Some(ndmin_kw) = ndmin_kw {
match &ndmin_kw.node.value.node { match &ndmin_kw.node.value.node {
ExprKind::Constant { value, .. } => match value { ExprKind::Constant { value, .. } => match value {
ast::Constant::Int(value) => *value as u64, ast::Constant::Int(value) => max(*value as u64, arg0_ndims),
_ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])), _ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])),
}, },
_ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap()), _ => arg0_ndims,
} }
} else { } else {
arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) arg0_ndims
}; };
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));