2024-10-03 12:37:56 +08:00
|
|
|
use itertools::Itertools;
|
|
|
|
|
2024-10-17 15:57:33 +08:00
|
|
|
use super::helper::PrimDef;
|
|
|
|
use crate::typecheck::{
|
|
|
|
type_inferencer::PrimitiveStore,
|
|
|
|
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
2023-11-17 17:30:27 +08:00
|
|
|
};
|
|
|
|
|
2024-02-27 13:39:05 +08:00
|
|
|
/// Creates a `ndarray` [`Type`] with the given type arguments.
|
2024-06-12 14:45:03 +08:00
|
|
|
///
|
2024-02-27 13:39:05 +08:00
|
|
|
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
2024-08-21 11:10:52 +08:00
|
|
|
/// specialized.
|
2024-02-27 13:39:05 +08:00
|
|
|
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
2024-08-21 11:10:52 +08:00
|
|
|
/// specialized.
|
2024-02-27 13:39:05 +08:00
|
|
|
pub fn make_ndarray_ty(
|
|
|
|
unifier: &mut Unifier,
|
|
|
|
primitives: &PrimitiveStore,
|
|
|
|
dtype: Option<Type>,
|
|
|
|
ndims: Option<Type>,
|
|
|
|
) -> Type {
|
2024-03-26 19:14:56 +08:00
|
|
|
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
|
|
|
|
}
|
2024-02-27 13:39:05 +08:00
|
|
|
|
2024-03-26 19:14:56 +08:00
|
|
|
/// Substitutes type variables in `ndarray`.
|
|
|
|
///
|
|
|
|
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
2024-08-21 11:10:52 +08:00
|
|
|
/// specialized.
|
2024-03-26 19:14:56 +08:00
|
|
|
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
2024-08-21 11:10:52 +08:00
|
|
|
/// specialized.
|
2024-03-26 19:14:56 +08:00
|
|
|
pub fn subst_ndarray_tvars(
|
|
|
|
unifier: &mut Unifier,
|
|
|
|
ndarray: Type,
|
|
|
|
dtype: Option<Type>,
|
|
|
|
ndims: Option<Type>,
|
|
|
|
) -> Type {
|
2024-02-27 13:39:05 +08:00
|
|
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
|
|
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
|
|
|
};
|
2024-06-12 15:01:01 +08:00
|
|
|
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
|
2024-02-27 13:39:05 +08:00
|
|
|
|
2024-03-26 19:14:56 +08:00
|
|
|
if dtype.is_none() && ndims.is_none() {
|
2024-06-12 14:45:03 +08:00
|
|
|
return ndarray;
|
2024-03-26 19:14:56 +08:00
|
|
|
}
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
|
2024-02-27 13:39:05 +08:00
|
|
|
debug_assert_eq!(tvar_ids.len(), 2);
|
|
|
|
|
2024-03-04 23:38:52 +08:00
|
|
|
let mut tvar_subst = VarMap::new();
|
2024-02-27 13:39:05 +08:00
|
|
|
if let Some(dtype) = dtype {
|
|
|
|
tvar_subst.insert(tvar_ids[0], dtype);
|
|
|
|
}
|
|
|
|
if let Some(ndims) = ndims {
|
|
|
|
tvar_subst.insert(tvar_ids[1], ndims);
|
|
|
|
}
|
|
|
|
|
|
|
|
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-06-13 13:28:39 +08:00
|
|
|
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
|
2024-02-27 13:39:05 +08:00
|
|
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
|
|
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
|
|
|
};
|
2024-06-12 15:01:01 +08:00
|
|
|
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
|
2024-02-27 13:39:05 +08:00
|
|
|
debug_assert_eq!(params.len(), 2);
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
params
|
|
|
|
.iter()
|
2024-02-27 13:39:05 +08:00
|
|
|
.sorted_by_key(|(obj_id, _)| *obj_id)
|
2024-03-26 19:14:56 +08:00
|
|
|
.map(|(var_id, ty)| (*var_id, *ty))
|
|
|
|
.collect_vec()
|
|
|
|
}
|
|
|
|
|
2024-06-12 14:45:03 +08:00
|
|
|
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
|
|
|
|
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
|
2024-03-26 19:14:56 +08:00
|
|
|
/// respectively.
|
2024-06-13 13:28:39 +08:00
|
|
|
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
|
2024-06-12 14:45:03 +08:00
|
|
|
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
|
2024-03-26 19:14:56 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
|
|
|
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
2024-06-12 14:45:03 +08:00
|
|
|
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
|
|
|
|
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
|
2024-02-27 13:39:05 +08:00
|
|
|
}
|