forked from M-Labs/nac3
core: Implement type inference for indexing into ndarray
This commit is contained in:
parent
976a9512c1
commit
0d5c53e60c
|
@ -1237,6 +1237,67 @@ impl<'a> Inferencer<'a> {
|
||||||
Ok(boolean)
|
Ok(boolean)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Infers the type of a subscript expression on an `ndarray`.
|
||||||
|
fn infer_subscript_ndarray(
|
||||||
|
&mut self,
|
||||||
|
value: &ast::Expr<Option<Type>>,
|
||||||
|
dummy_tvar: Type,
|
||||||
|
ndims: &Type,
|
||||||
|
) -> InferenceResult {
|
||||||
|
debug_assert!(matches!(
|
||||||
|
&*self.unifier.get_ty_immutable(dummy_tvar),
|
||||||
|
TypeEnum::TVar { is_const_generic: false, .. }
|
||||||
|
));
|
||||||
|
|
||||||
|
let constrained_ty = self.unifier.add_ty(TypeEnum::TNDArray { ty: dummy_tvar, ndims: *ndims });
|
||||||
|
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
|
||||||
|
|
||||||
|
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(*ndims) else {
|
||||||
|
panic!("Expected TLiteral for TNDArray.ndims, got {}", self.unifier.stringify(*ndims))
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndims = values.iter()
|
||||||
|
.map(|ndim| match *ndim {
|
||||||
|
SymbolValue::U64(v) => Ok(v),
|
||||||
|
SymbolValue::U32(v) => Ok(v as u64),
|
||||||
|
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| HashSet::from([
|
||||||
|
format!("Expected non-negative literal for TNDArray.ndims, got {v}"),
|
||||||
|
])),
|
||||||
|
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| HashSet::from([
|
||||||
|
format!("Expected non-negative literal for TNDArray.ndims, got {v}"),
|
||||||
|
])),
|
||||||
|
_ => unreachable!(),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
|
assert!(!ndims.is_empty());
|
||||||
|
|
||||||
|
if ndims.len() == 1 && ndims[0] == 1 {
|
||||||
|
// ndarray[T, Literal[1]] - Index always returns an object of type T
|
||||||
|
|
||||||
|
assert_ne!(ndims[0], 0);
|
||||||
|
|
||||||
|
Ok(dummy_tvar)
|
||||||
|
} else {
|
||||||
|
// ndarray[T, Literal[N]] where N != 1 - Index returns an object of type ndarray[T, Literal[N - 1]]
|
||||||
|
|
||||||
|
if ndims.iter().any(|v| *v == 0) {
|
||||||
|
unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
let ndims_min_one_ty = self.unifier.get_fresh_literal(
|
||||||
|
ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let subscripted_ty = self.unifier.add_ty(TypeEnum::TNDArray {
|
||||||
|
ty: dummy_tvar,
|
||||||
|
ndims: ndims_min_one_ty,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(subscripted_ty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn infer_subscript(
|
fn infer_subscript(
|
||||||
&mut self,
|
&mut self,
|
||||||
value: &ast::Expr<Option<Type>>,
|
value: &ast::Expr<Option<Type>>,
|
||||||
|
@ -1258,33 +1319,41 @@ impl<'a> Inferencer<'a> {
|
||||||
Ok(list_like_ty)
|
Ok(list_like_ty)
|
||||||
}
|
}
|
||||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||||
// the index is a constant, so value can be a sequence.
|
if let TypeEnum::TNDArray { ndims, .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
let ind: Option<i32> = (*val).try_into().ok();
|
self.infer_subscript_ndarray(value, ty, ndims)
|
||||||
let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
|
} else {
|
||||||
let map = once((
|
// the index is a constant, so value can be a sequence.
|
||||||
ind.into(),
|
let ind: Option<i32> = (*val).try_into().ok();
|
||||||
RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)),
|
let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
|
||||||
))
|
let map = once((
|
||||||
.collect();
|
ind.into(),
|
||||||
let seq = self.unifier.add_record(map);
|
RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)),
|
||||||
self.constrain(value.custom.unwrap(), seq, &value.location)?;
|
))
|
||||||
Ok(ty)
|
.collect();
|
||||||
|
let seq = self.unifier.add_record(map);
|
||||||
|
self.constrain(value.custom.unwrap(), seq, &value.location)?;
|
||||||
|
Ok(ty)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap())
|
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
{
|
|
||||||
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
|
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
|
||||||
}
|
}
|
||||||
|
|
||||||
// the index is not a constant, so value can only be a list
|
// the index is not a constant, so value can only be a list-like structure
|
||||||
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
|
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
TypeEnum::TList { .. } => {
|
||||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
|
||||||
TypeEnum::TNDArray { .. } => todo!(),
|
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
||||||
|
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
||||||
|
Ok(ty)
|
||||||
|
}
|
||||||
|
TypeEnum::TNDArray { ndims, .. } => {
|
||||||
|
self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?;
|
||||||
|
self.infer_subscript_ndarray(value, ty, ndims)
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
}
|
||||||
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
|
|
||||||
Ok(ty)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue