forked from M-Labs/nac3
core/typecheck: Basic ndarray indexing support
This commit is contained in:
parent
72cb693e2e
commit
c74b7992f6
|
@ -1156,6 +1156,7 @@ impl<'a> Inferencer<'a> {
|
||||||
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
||||||
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
|
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
|
||||||
}
|
}
|
||||||
|
// xxx: Support TNDArray.
|
||||||
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
||||||
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
||||||
Ok(list)
|
Ok(list)
|
||||||
|
@ -1174,17 +1175,59 @@ impl<'a> Inferencer<'a> {
|
||||||
Ok(ty)
|
Ok(ty)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
return report_error(
|
TypeEnum::TTuple { .. } => return report_error(
|
||||||
"Tuple index must be a constant (KernelInvariant is also not supported)",
|
"Tuple index must be a constant (KernelInvariant is also not supported)",
|
||||||
slice.location,
|
slice.location,
|
||||||
);
|
),
|
||||||
|
TypeEnum::TNDArray { ty: elem_ty, num_dims } => {
|
||||||
|
let num_idxs = if let TypeEnum::TTuple { ty: idx_tys } =
|
||||||
|
&*self.unifier.get_ty(slice.custom.unwrap())
|
||||||
|
{
|
||||||
|
for idx_ty in idx_tys.iter() {
|
||||||
|
self.constrain(*idx_ty, self.primitives.int32, &slice.location)?;
|
||||||
|
}
|
||||||
|
idx_tys.len()
|
||||||
|
} else {
|
||||||
|
// xxx: Could lead to suboptimal error message, as higher-dimensional indexing is not mentioned?!
|
||||||
|
self.constrain(
|
||||||
|
slice.custom.unwrap(),
|
||||||
|
self.primitives.int32,
|
||||||
|
&slice.location,
|
||||||
|
)?;
|
||||||
|
1
|
||||||
|
};
|
||||||
|
|
||||||
|
if *num_dims < num_idxs {
|
||||||
|
report_error(
|
||||||
|
&format!(
|
||||||
|
"ndarray has dimension {}, but {} indices supplied",
|
||||||
|
num_dims, num_idxs
|
||||||
|
),
|
||||||
|
slice.location,
|
||||||
|
)
|
||||||
|
} else if *num_dims == num_idxs {
|
||||||
|
Ok(*elem_ty)
|
||||||
|
} else {
|
||||||
|
Ok(self.unifier.add_ty(TypeEnum::TNDArray {
|
||||||
|
ty: *elem_ty,
|
||||||
|
num_dims: *num_dims - num_idxs,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// the index is not a constant, so value can only be a list
|
||||||
|
// xxx: Or an ndarray now, so remove the constraint?
|
||||||
|
self.constrain(
|
||||||
|
slice.custom.unwrap(),
|
||||||
|
self.primitives.int32,
|
||||||
|
&slice.location,
|
||||||
|
)?;
|
||||||
|
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
||||||
|
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
||||||
|
Ok(ty)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// the index is not a constant, so value can only be a list
|
|
||||||
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
|
|
||||||
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
|
||||||
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
|
||||||
Ok(ty)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -516,9 +516,14 @@ impl TestEnvironment {
|
||||||
#[test_case(
|
#[test_case(
|
||||||
indoc! {"
|
indoc! {"
|
||||||
a = array([1, 2])
|
a = array([1, 2])
|
||||||
|
a0 = a[0]
|
||||||
b = array([[1, 2], [3, 4]])
|
b = array([[1, 2], [3, 4]])
|
||||||
|
# b0 = b[0]
|
||||||
|
b00 = b[0, 0]
|
||||||
|
c = 1
|
||||||
|
ac = a[c]
|
||||||
"},
|
"},
|
||||||
[("a", "ndarray[int32, 1]"), ("b", "ndarray[int32, 2]")].iter().cloned().collect(),
|
[("a", "ndarray[int32, 1]"), ("b", "ndarray[int32, 2]"), ("a0", "int32"), ("b00", "int32"), ("ac", "int32")].iter().cloned().collect(),
|
||||||
&[]
|
&[]
|
||||||
; "array test")]
|
; "array test")]
|
||||||
#[test_case(indoc! {"
|
#[test_case(indoc! {"
|
||||||
|
|
|
@ -139,7 +139,10 @@ pub enum TypeEnum {
|
||||||
},
|
},
|
||||||
TNDArray {
|
TNDArray {
|
||||||
ty: Type,
|
ty: Type,
|
||||||
num_dims: u8,
|
|
||||||
|
// We could introduce a more sensible limit for the number of dimensions
|
||||||
|
// and make this e.g. u8; usize for now to avoid some casts.
|
||||||
|
num_dims: usize,
|
||||||
},
|
},
|
||||||
TObj {
|
TObj {
|
||||||
obj_id: DefinitionId,
|
obj_id: DefinitionId,
|
||||||
|
@ -655,6 +658,24 @@ impl Unifier {
|
||||||
self.unify_impl(x, b, false)?;
|
self.unify_impl(x, b, false)?;
|
||||||
self.set_a_to_b(a, x);
|
self.set_a_to_b(a, x);
|
||||||
}
|
}
|
||||||
|
(TVar { fields: Some(fields), range, .. }, TNDArray { ty, num_dims }) => {
|
||||||
|
for (k, v) in fields.iter() {
|
||||||
|
match *k {
|
||||||
|
RecordKey::Int(_) => {
|
||||||
|
if *num_dims > 1 {
|
||||||
|
unreachable!("xxx implement unification for scalar indexing of multidimensional array");
|
||||||
|
}
|
||||||
|
self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))?
|
||||||
|
}
|
||||||
|
RecordKey::Str(_) => {
|
||||||
|
return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let x = self.check_var_compatibility(b, range)?.unwrap_or(b);
|
||||||
|
self.unify_impl(x, b, false)?;
|
||||||
|
self.set_a_to_b(a, x);
|
||||||
|
}
|
||||||
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
||||||
if ty1.len() != ty2.len() {
|
if ty1.len() != ty2.len() {
|
||||||
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
|
||||||
|
|
Loading…
Reference in New Issue