diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a6b72bba..f1f2a5a2 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -13,7 +13,7 @@ use crate::{ TopLevelContext, }, }; -use itertools::izip; +use itertools::{Itertools, izip}; use nac3parser::ast::{ self, fold::{self, Fold}, @@ -59,6 +59,16 @@ pub struct PrimitiveStore { } impl PrimitiveStore { + /// Returns a [`Type`] representing a signed representation of `size_t`. + #[must_use] + pub fn isize(&self) -> Type { + match self.size_t { + 32 => self.int32, + 64 => self.int64, + _ => unreachable!(), + } + } + /// Returns a [Type] representing `size_t`. #[must_use] pub fn usize(&self) -> Type { @@ -1381,7 +1391,16 @@ impl<'a> Inferencer<'a> { TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); - self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?; + let valid_index_tys = [ + self.primitives.int32, + self.primitives.isize(), + ].into_iter().unique().collect_vec(); + let valid_index_ty = self.unifier.get_fresh_var_with_range( + valid_index_tys.as_slice(), + None, + None, + ).0; + self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?; self.infer_subscript_ndarray(value, ty, ndims) } _ => unreachable!(),