diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 251b91e..1925174 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1586,6 +1586,7 @@ impl<'a> Inferencer<'a> { fn infer_subscript_ndarray( &mut self, value: &ast::Expr>, + slice: &ast::Expr>, dummy_tvar: Type, ndims: Type, ) -> InferenceResult { @@ -1604,48 +1605,66 @@ impl<'a> Inferencer<'a> { let ndims = values .iter() - .map(|ndim| match *ndim { - SymbolValue::U64(v) => Ok(v), - SymbolValue::U32(v) => Ok(u64::from(v)), - SymbolValue::I32(v) => u64::try_from(v).map_err(|_| { - HashSet::from([format!( - "Expected non-negative literal for ndarray.ndims, got {v}" - )]) - }), - SymbolValue::I64(v) => u64::try_from(v).map_err(|_| { - HashSet::from([format!( - "Expected non-negative literal for ndarray.ndims, got {v}" - )]) - }), - _ => unreachable!(), - }) - .collect::, _>>()?; + .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) + .collect::, _>>() + .map_err(|val| { + HashSet::from([format!( + "Expected non-negative literal for ndarray.ndims, got {}", + i128::try_from(val).unwrap() + )]) + })?; assert!(!ndims.is_empty()); - if ndims.len() == 1 && ndims[0] == 1 { - // ndarray[T, Literal[1]] - Index always returns an object of type T + // The number of dimensions subscripted by the index expression. + // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a + // dimension will remove a dimension. + let subscripted_dims = match &slice.node { + ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| { + if let ExprKind::Slice { .. } = &value_subexpr.node { + acc + } else { + acc + 1 + } + }), + + ExprKind::Slice { .. } => 0, + _ => 1, + }; + + if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 { + // ndarray[T, Literal[1]] - Non-Slice index always returns an object of type T assert_ne!(ndims[0], 0); Ok(dummy_tvar) } else { - // ndarray[T, Literal[N]] where N != 1 - Index returns an object of type ndarray[T, Literal[N - 1]] + // Otherwise - Index returns an object of type ndarray[T, Literal[N - subscripted_dims]] - if ndims.iter().any(|v| *v == 0) { + // Disallow subscripting if any Literal value will subscript on an element + let new_ndims = ndims + .into_iter() + .map(|v| { + let v = i128::from(v) - i128::from(subscripted_dims); + u64::try_from(v) + }) + .collect::, _>>() + .map_err(|_| { + HashSet::from([format!( + "Cannot subscript {} by {subscripted_dims} dimensions", + self.unifier.stringify(value.custom.unwrap()), + )]) + })?; + + if new_ndims.iter().any(|v| *v == 0) { unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented") } - let ndims_min_one_ty = self.unifier.get_fresh_literal( - ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(), - None, - ); - let subscripted_ty = make_ndarray_ty( - self.unifier, - self.primitives, - Some(dummy_tvar), - Some(ndims_min_one_ty), - ); + let ndims_ty = self + .unifier + .get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None); + let subscripted_ty = + make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty)); Ok(subscripted_ty) } @@ -1682,7 +1701,7 @@ impl<'a> Inferencer<'a> { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); - self.infer_subscript_ndarray(value, ty, ndims) + self.infer_subscript_ndarray(value, slice, ty, ndims) } _ => { // the index is a constant, so value can be a sequence. @@ -1725,10 +1744,7 @@ impl<'a> Inferencer<'a> { } let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); - let ndarray_ty = - make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); - self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?; - Ok(ndarray_ty) + self.infer_subscript_ndarray(value, slice, ty, ndims) } _ => { if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { @@ -1763,7 +1779,7 @@ impl<'a> Inferencer<'a> { .get_fresh_var_with_range(valid_index_tys.as_slice(), None, None) .ty; self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?; - self.infer_subscript_ndarray(value, ty, ndims) + self.infer_subscript_ndarray(value, slice, ty, ndims) } _ => unreachable!(), }