core: Fix type inference for tuple-index into ndarray

Fixes #420.
This commit is contained in:
David Mak 2024-06-20 11:45:30 +08:00 committed by sb10q
parent e36af3b0a3
commit 635c944c90
1 changed files with 52 additions and 36 deletions

View File

@ -1586,6 +1586,7 @@ impl<'a> Inferencer<'a> {
fn infer_subscript_ndarray( fn infer_subscript_ndarray(
&mut self, &mut self,
value: &ast::Expr<Option<Type>>, value: &ast::Expr<Option<Type>>,
slice: &ast::Expr<Option<Type>>,
dummy_tvar: Type, dummy_tvar: Type,
ndims: Type, ndims: Type,
) -> InferenceResult { ) -> InferenceResult {
@ -1604,48 +1605,66 @@ impl<'a> Inferencer<'a> {
let ndims = values let ndims = values
.iter() .iter()
.map(|ndim| match *ndim { .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
SymbolValue::U64(v) => Ok(v), .collect::<Result<Vec<_>, _>>()
SymbolValue::U32(v) => Ok(u64::from(v)), .map_err(|val| {
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| {
HashSet::from([format!( HashSet::from([format!(
"Expected non-negative literal for ndarray.ndims, got {v}" "Expected non-negative literal for ndarray.ndims, got {}",
i128::try_from(val).unwrap()
)]) )])
}), })?;
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| {
HashSet::from([format!(
"Expected non-negative literal for ndarray.ndims, got {v}"
)])
}),
_ => unreachable!(),
})
.collect::<Result<Vec<_>, _>>()?;
assert!(!ndims.is_empty()); assert!(!ndims.is_empty());
if ndims.len() == 1 && ndims[0] == 1 { // The number of dimensions subscripted by the index expression.
// ndarray[T, Literal[1]] - Index always returns an object of type T // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
// dimension will remove a dimension.
let subscripted_dims = match &slice.node {
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
if let ExprKind::Slice { .. } = &value_subexpr.node {
acc
} else {
acc + 1
}
}),
ExprKind::Slice { .. } => 0,
_ => 1,
};
if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
// ndarray[T, Literal[1]] - Non-Slice index always returns an object of type T
assert_ne!(ndims[0], 0); assert_ne!(ndims[0], 0);
Ok(dummy_tvar) Ok(dummy_tvar)
} else { } else {
// ndarray[T, Literal[N]] where N != 1 - Index returns an object of type ndarray[T, Literal[N - 1]] // Otherwise - Index returns an object of type ndarray[T, Literal[N - subscripted_dims]]
if ndims.iter().any(|v| *v == 0) { // Disallow subscripting if any Literal value will subscript on an element
let new_ndims = ndims
.into_iter()
.map(|v| {
let v = i128::from(v) - i128::from(subscripted_dims);
u64::try_from(v)
})
.collect::<Result<Vec<_>, _>>()
.map_err(|_| {
HashSet::from([format!(
"Cannot subscript {} by {subscripted_dims} dimensions",
self.unifier.stringify(value.custom.unwrap()),
)])
})?;
if new_ndims.iter().any(|v| *v == 0) {
unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented") unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented")
} }
let ndims_min_one_ty = self.unifier.get_fresh_literal( let ndims_ty = self
ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(), .unifier
None, .get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
); let subscripted_ty =
let subscripted_ty = make_ndarray_ty( make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty));
self.unifier,
self.primitives,
Some(dummy_tvar),
Some(ndims_min_one_ty),
);
Ok(subscripted_ty) Ok(subscripted_ty)
} }
@ -1682,7 +1701,7 @@ impl<'a> Inferencer<'a> {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) = let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
self.infer_subscript_ndarray(value, ty, ndims) self.infer_subscript_ndarray(value, slice, ty, ndims)
} }
_ => { _ => {
// the index is a constant, so value can be a sequence. // the index is a constant, so value can be a sequence.
@ -1725,10 +1744,7 @@ impl<'a> Inferencer<'a> {
} }
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndarray_ty = self.infer_subscript_ndarray(value, slice, ty, ndims)
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
Ok(ndarray_ty)
} }
_ => { _ => {
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
@ -1763,7 +1779,7 @@ impl<'a> Inferencer<'a> {
.get_fresh_var_with_range(valid_index_tys.as_slice(), None, None) .get_fresh_var_with_range(valid_index_tys.as_slice(), None, None)
.ty; .ty;
self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?; self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?;
self.infer_subscript_ndarray(value, ty, ndims) self.infer_subscript_ndarray(value, slice, ty, ndims)
} }
_ => unreachable!(), _ => unreachable!(),
} }