parent
5b1aa812ed
commit
95eab02506
@ -1586,6 +1586,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
fn infer_subscript_ndarray(
|
fn infer_subscript_ndarray(
|
||||||
&mut self,
|
&mut self,
|
||||||
value: &ast::Expr<Option<Type>>,
|
value: &ast::Expr<Option<Type>>,
|
||||||
|
slice: &ast::Expr<Option<Type>>,
|
||||||
dummy_tvar: Type,
|
dummy_tvar: Type,
|
||||||
ndims: Type,
|
ndims: Type,
|
||||||
) -> InferenceResult {
|
) -> InferenceResult {
|
||||||
@ -1604,48 +1605,66 @@ impl<'a> Inferencer<'a> {
|
|||||||
|
|
||||||
let ndims = values
|
let ndims = values
|
||||||
.iter()
|
.iter()
|
||||||
.map(|ndim| match *ndim {
|
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
|
||||||
SymbolValue::U64(v) => Ok(v),
|
.collect::<Result<Vec<_>, _>>()
|
||||||
SymbolValue::U32(v) => Ok(u64::from(v)),
|
.map_err(|val| {
|
||||||
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| {
|
|
||||||
HashSet::from([format!(
|
HashSet::from([format!(
|
||||||
"Expected non-negative literal for ndarray.ndims, got {v}"
|
"Expected non-negative literal for ndarray.ndims, got {}",
|
||||||
|
i128::try_from(val).unwrap()
|
||||||
)])
|
)])
|
||||||
}),
|
})?;
|
||||||
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| {
|
|
||||||
HashSet::from([format!(
|
|
||||||
"Expected non-negative literal for ndarray.ndims, got {v}"
|
|
||||||
)])
|
|
||||||
}),
|
|
||||||
_ => unreachable!(),
|
|
||||||
})
|
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
|
||||||
|
|
||||||
assert!(!ndims.is_empty());
|
assert!(!ndims.is_empty());
|
||||||
|
|
||||||
if ndims.len() == 1 && ndims[0] == 1 {
|
// The number of dimensions subscripted by the index expression.
|
||||||
// ndarray[T, Literal[1]] - Index always returns an object of type T
|
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
|
||||||
|
// dimension will remove a dimension.
|
||||||
|
let subscripted_dims = match &slice.node {
|
||||||
|
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
|
||||||
|
if let ExprKind::Slice { .. } = &value_subexpr.node {
|
||||||
|
acc
|
||||||
|
} else {
|
||||||
|
acc + 1
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
|
ExprKind::Slice { .. } => 0,
|
||||||
|
_ => 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
|
||||||
|
// ndarray[T, Literal[1]] - Non-Slice index always returns an object of type T
|
||||||
|
|
||||||
assert_ne!(ndims[0], 0);
|
assert_ne!(ndims[0], 0);
|
||||||
|
|
||||||
Ok(dummy_tvar)
|
Ok(dummy_tvar)
|
||||||
} else {
|
} else {
|
||||||
// ndarray[T, Literal[N]] where N != 1 - Index returns an object of type ndarray[T, Literal[N - 1]]
|
// Otherwise - Index returns an object of type ndarray[T, Literal[N - subscripted_dims]]
|
||||||
|
|
||||||
if ndims.iter().any(|v| *v == 0) {
|
// Disallow subscripting if any Literal value will subscript on an element
|
||||||
|
let new_ndims = ndims
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| {
|
||||||
|
let v = i128::from(v) - i128::from(subscripted_dims);
|
||||||
|
u64::try_from(v)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
.map_err(|_| {
|
||||||
|
HashSet::from([format!(
|
||||||
|
"Cannot subscript {} by {subscripted_dims} dimensions",
|
||||||
|
self.unifier.stringify(value.custom.unwrap()),
|
||||||
|
)])
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if new_ndims.iter().any(|v| *v == 0) {
|
||||||
unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented")
|
unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndims_min_one_ty = self.unifier.get_fresh_literal(
|
let ndims_ty = self
|
||||||
ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(),
|
.unifier
|
||||||
None,
|
.get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
|
||||||
);
|
let subscripted_ty =
|
||||||
let subscripted_ty = make_ndarray_ty(
|
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty));
|
||||||
self.unifier,
|
|
||||||
self.primitives,
|
|
||||||
Some(dummy_tvar),
|
|
||||||
Some(ndims_min_one_ty),
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(subscripted_ty)
|
Ok(subscripted_ty)
|
||||||
}
|
}
|
||||||
@ -1682,7 +1701,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (_, ndims) =
|
let (_, ndims) =
|
||||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||||
self.infer_subscript_ndarray(value, ty, ndims)
|
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
// the index is a constant, so value can be a sequence.
|
// the index is a constant, so value can be a sequence.
|
||||||
@ -1725,10 +1744,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||||
let ndarray_ty =
|
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
|
|
||||||
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
|
|
||||||
Ok(ndarray_ty)
|
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
@ -1763,7 +1779,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
.get_fresh_var_with_range(valid_index_tys.as_slice(), None, None)
|
.get_fresh_var_with_range(valid_index_tys.as_slice(), None, None)
|
||||||
.ty;
|
.ty;
|
||||||
self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?;
|
self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?;
|
||||||
self.infer_subscript_ndarray(value, ty, ndims)
|
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user