From c74b7992f65bd0dba820ac6a2d4786a22b8c3b70 Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Fri, 22 Apr 2022 21:55:27 +0100 Subject: [PATCH] core/typecheck: Basic ndarray indexing support --- nac3core/src/typecheck/type_inferencer/mod.rs | 59 ++++++++++++++++--- .../src/typecheck/type_inferencer/test.rs | 7 ++- nac3core/src/typecheck/typedef/mod.rs | 23 +++++++- 3 files changed, 79 insertions(+), 10 deletions(-) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 05fb9885f..db86ccdbf 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1156,6 +1156,7 @@ impl<'a> Inferencer<'a> { for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?; } + // xxx: Support TNDArray. let list = self.unifier.add_ty(TypeEnum::TList { ty }); self.constrain(value.custom.unwrap(), list, &value.location)?; Ok(list) @@ -1174,17 +1175,59 @@ impl<'a> Inferencer<'a> { Ok(ty) } _ => { - if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { - return report_error( + match &*self.unifier.get_ty(value.custom.unwrap()) { + TypeEnum::TTuple { .. } => return report_error( "Tuple index must be a constant (KernelInvariant is also not supported)", slice.location, - ); + ), + TypeEnum::TNDArray { ty: elem_ty, num_dims } => { + let num_idxs = if let TypeEnum::TTuple { ty: idx_tys } = + &*self.unifier.get_ty(slice.custom.unwrap()) + { + for idx_ty in idx_tys.iter() { + self.constrain(*idx_ty, self.primitives.int32, &slice.location)?; + } + idx_tys.len() + } else { + // xxx: Could lead to suboptimal error message, as higher-dimensional indexing is not mentioned?! + self.constrain( + slice.custom.unwrap(), + self.primitives.int32, + &slice.location, + )?; + 1 + }; + + if *num_dims < num_idxs { + report_error( + &format!( + "ndarray has dimension {}, but {} indices supplied", + num_dims, num_idxs + ), + slice.location, + ) + } else if *num_dims == num_idxs { + Ok(*elem_ty) + } else { + Ok(self.unifier.add_ty(TypeEnum::TNDArray { + ty: *elem_ty, + num_dims: *num_dims - num_idxs, + })) + } + } + _ => { + // the index is not a constant, so value can only be a list + // xxx: Or an ndarray now, so remove the constraint? + self.constrain( + slice.custom.unwrap(), + self.primitives.int32, + &slice.location, + )?; + let list = self.unifier.add_ty(TypeEnum::TList { ty }); + self.constrain(value.custom.unwrap(), list, &value.location)?; + Ok(ty) + } } - // the index is not a constant, so value can only be a list - self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?; - let list = self.unifier.add_ty(TypeEnum::TList { ty }); - self.constrain(value.custom.unwrap(), list, &value.location)?; - Ok(ty) } } } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 3e58795c7..0ac4d99db 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -516,9 +516,14 @@ impl TestEnvironment { #[test_case( indoc! {" a = array([1, 2]) + a0 = a[0] b = array([[1, 2], [3, 4]]) + # b0 = b[0] + b00 = b[0, 0] + c = 1 + ac = a[c] "}, - [("a", "ndarray[int32, 1]"), ("b", "ndarray[int32, 2]")].iter().cloned().collect(), + [("a", "ndarray[int32, 1]"), ("b", "ndarray[int32, 2]"), ("a0", "int32"), ("b00", "int32"), ("ac", "int32")].iter().cloned().collect(), &[] ; "array test")] #[test_case(indoc! {" diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index fdd226aec..95c165623 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -139,7 +139,10 @@ pub enum TypeEnum { }, TNDArray { ty: Type, - num_dims: u8, + + // We could introduce a more sensible limit for the number of dimensions + // and make this e.g. u8; usize for now to avoid some casts. + num_dims: usize, }, TObj { obj_id: DefinitionId, @@ -655,6 +658,24 @@ impl Unifier { self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } + (TVar { fields: Some(fields), range, .. }, TNDArray { ty, num_dims }) => { + for (k, v) in fields.iter() { + match *k { + RecordKey::Int(_) => { + if *num_dims > 1 { + unreachable!("xxx implement unification for scalar indexing of multidimensional array"); + } + self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))? + } + RecordKey::Str(_) => { + return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc)) + } + } + } + let x = self.check_var_compatibility(b, range)?.unwrap_or(b); + self.unify_impl(x, b, false)?; + self.set_a_to_b(a, x); + } (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { if ty1.len() != ty2.len() { return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));