diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 1462ca6..8f12685 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1237,6 +1237,67 @@ impl<'a> Inferencer<'a> { Ok(boolean) } + /// Infers the type of a subscript expression on an `ndarray`. + fn infer_subscript_ndarray( + &mut self, + value: &ast::Expr>, + dummy_tvar: Type, + ndims: &Type, + ) -> InferenceResult { + debug_assert!(matches!( + &*self.unifier.get_ty_immutable(dummy_tvar), + TypeEnum::TVar { is_const_generic: false, .. } + )); + + let constrained_ty = self.unifier.add_ty(TypeEnum::TNDArray { ty: dummy_tvar, ndims: *ndims }); + self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?; + + let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(*ndims) else { + panic!("Expected TLiteral for TNDArray.ndims, got {}", self.unifier.stringify(*ndims)) + }; + + let ndims = values.iter() + .map(|ndim| match *ndim { + SymbolValue::U64(v) => Ok(v), + SymbolValue::U32(v) => Ok(v as u64), + SymbolValue::I32(v) => u64::try_from(v).map_err(|_| HashSet::from([ + format!("Expected non-negative literal for TNDArray.ndims, got {v}"), + ])), + SymbolValue::I64(v) => u64::try_from(v).map_err(|_| HashSet::from([ + format!("Expected non-negative literal for TNDArray.ndims, got {v}"), + ])), + _ => unreachable!(), + }) + .collect::, _>>()?; + + assert!(!ndims.is_empty()); + + if ndims.len() == 1 && ndims[0] == 1 { + // ndarray[T, Literal[1]] - Index always returns an object of type T + + assert_ne!(ndims[0], 0); + + Ok(dummy_tvar) + } else { + // ndarray[T, Literal[N]] where N != 1 - Index returns an object of type ndarray[T, Literal[N - 1]] + + if ndims.iter().any(|v| *v == 0) { + unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented") + } + + let ndims_min_one_ty = self.unifier.get_fresh_literal( + ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(), + None, + ); + let subscripted_ty = self.unifier.add_ty(TypeEnum::TNDArray { + ty: dummy_tvar, + ndims: ndims_min_one_ty, + }); + + Ok(subscripted_ty) + } + } + fn infer_subscript( &mut self, value: &ast::Expr>, @@ -1258,33 +1319,41 @@ impl<'a> Inferencer<'a> { Ok(list_like_ty) } ExprKind::Constant { value: ast::Constant::Int(val), .. } => { - // the index is a constant, so value can be a sequence. - let ind: Option = (*val).try_into().ok(); - let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?; - let map = once(( - ind.into(), - RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)), - )) - .collect(); - let seq = self.unifier.add_record(map); - self.constrain(value.custom.unwrap(), seq, &value.location)?; - Ok(ty) + if let TypeEnum::TNDArray { ndims, .. } = &*self.unifier.get_ty(value.custom.unwrap()) { + self.infer_subscript_ndarray(value, ty, ndims) + } else { + // the index is a constant, so value can be a sequence. + let ind: Option = (*val).try_into().ok(); + let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?; + let map = once(( + ind.into(), + RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)), + )) + .collect(); + let seq = self.unifier.add_record(map); + self.constrain(value.custom.unwrap(), seq, &value.location)?; + Ok(ty) + } } _ => { - if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) - { + if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location) } - // the index is not a constant, so value can only be a list - self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?; - let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { - TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), - TypeEnum::TNDArray { .. } => todo!(), + // the index is not a constant, so value can only be a list-like structure + match &*self.unifier.get_ty(value.custom.unwrap()) { + TypeEnum::TList { .. } => { + self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?; + let list = self.unifier.add_ty(TypeEnum::TList { ty }); + self.constrain(value.custom.unwrap(), list, &value.location)?; + Ok(ty) + } + TypeEnum::TNDArray { ndims, .. } => { + self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?; + self.infer_subscript_ndarray(value, ty, ndims) + } _ => unreachable!(), - }; - self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?; - Ok(ty) + } } } }