core/typecheck: Basic ndarray indexing support

This commit is contained in:
David Nadlinger 2022-04-22 21:55:27 +01:00
parent 72cb693e2e
commit c74b7992f6
3 changed files with 79 additions and 10 deletions

View File

@ -1156,6 +1156,7 @@ impl<'a> Inferencer<'a> {
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?; self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
} }
// xxx: Support TNDArray.
let list = self.unifier.add_ty(TypeEnum::TList { ty }); let list = self.unifier.add_ty(TypeEnum::TList { ty });
self.constrain(value.custom.unwrap(), list, &value.location)?; self.constrain(value.custom.unwrap(), list, &value.location)?;
Ok(list) Ok(list)
@ -1174,20 +1175,62 @@ impl<'a> Inferencer<'a> {
Ok(ty) Ok(ty)
} }
_ => { _ => {
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { match &*self.unifier.get_ty(value.custom.unwrap()) {
return report_error( TypeEnum::TTuple { .. } => return report_error(
"Tuple index must be a constant (KernelInvariant is also not supported)", "Tuple index must be a constant (KernelInvariant is also not supported)",
slice.location, slice.location,
); ),
TypeEnum::TNDArray { ty: elem_ty, num_dims } => {
let num_idxs = if let TypeEnum::TTuple { ty: idx_tys } =
&*self.unifier.get_ty(slice.custom.unwrap())
{
for idx_ty in idx_tys.iter() {
self.constrain(*idx_ty, self.primitives.int32, &slice.location)?;
} }
idx_tys.len()
} else {
// xxx: Could lead to suboptimal error message, as higher-dimensional indexing is not mentioned?!
self.constrain(
slice.custom.unwrap(),
self.primitives.int32,
&slice.location,
)?;
1
};
if *num_dims < num_idxs {
report_error(
&format!(
"ndarray has dimension {}, but {} indices supplied",
num_dims, num_idxs
),
slice.location,
)
} else if *num_dims == num_idxs {
Ok(*elem_ty)
} else {
Ok(self.unifier.add_ty(TypeEnum::TNDArray {
ty: *elem_ty,
num_dims: *num_dims - num_idxs,
}))
}
}
_ => {
// the index is not a constant, so value can only be a list // the index is not a constant, so value can only be a list
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?; // xxx: Or an ndarray now, so remove the constraint?
self.constrain(
slice.custom.unwrap(),
self.primitives.int32,
&slice.location,
)?;
let list = self.unifier.add_ty(TypeEnum::TList { ty }); let list = self.unifier.add_ty(TypeEnum::TList { ty });
self.constrain(value.custom.unwrap(), list, &value.location)?; self.constrain(value.custom.unwrap(), list, &value.location)?;
Ok(ty) Ok(ty)
} }
} }
} }
}
}
fn infer_if_expr( fn infer_if_expr(
&mut self, &mut self,

View File

@ -516,9 +516,14 @@ impl TestEnvironment {
#[test_case( #[test_case(
indoc! {" indoc! {"
a = array([1, 2]) a = array([1, 2])
a0 = a[0]
b = array([[1, 2], [3, 4]]) b = array([[1, 2], [3, 4]])
# b0 = b[0]
b00 = b[0, 0]
c = 1
ac = a[c]
"}, "},
[("a", "ndarray[int32, 1]"), ("b", "ndarray[int32, 2]")].iter().cloned().collect(), [("a", "ndarray[int32, 1]"), ("b", "ndarray[int32, 2]"), ("a0", "int32"), ("b00", "int32"), ("ac", "int32")].iter().cloned().collect(),
&[] &[]
; "array test")] ; "array test")]
#[test_case(indoc! {" #[test_case(indoc! {"

View File

@ -139,7 +139,10 @@ pub enum TypeEnum {
}, },
TNDArray { TNDArray {
ty: Type, ty: Type,
num_dims: u8,
// We could introduce a more sensible limit for the number of dimensions
// and make this e.g. u8; usize for now to avoid some casts.
num_dims: usize,
}, },
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
@ -655,6 +658,24 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { fields: Some(fields), range, .. }, TNDArray { ty, num_dims }) => {
for (k, v) in fields.iter() {
match *k {
RecordKey::Int(_) => {
if *num_dims > 1 {
unreachable!("xxx implement unification for scalar indexing of multidimensional array");
}
self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))?
}
RecordKey::Str(_) => {
return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc))
}
}
}
let x = self.check_var_compatibility(b, range)?.unwrap_or(b);
self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x);
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() { if ty1.len() != ty2.len() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));