2024-02-27 13:39:05 +08:00
|
|
|
use itertools::Itertools;
|
2023-11-17 17:30:27 +08:00
|
|
|
use crate::{
|
2024-03-11 14:47:01 +08:00
|
|
|
toplevel::helper::PRIMITIVE_DEF_IDS,
|
2024-02-27 13:39:05 +08:00
|
|
|
typecheck::{
|
|
|
|
type_inferencer::PrimitiveStore,
|
2024-03-11 14:47:01 +08:00
|
|
|
typedef::{Type, TypeEnum, Unifier, VarMap},
|
2024-02-27 13:39:05 +08:00
|
|
|
},
|
2023-11-17 17:30:27 +08:00
|
|
|
};
|
|
|
|
|
2024-02-27 13:39:05 +08:00
|
|
|
/// Creates a `ndarray` [`Type`] with the given type arguments.
|
|
|
|
///
|
|
|
|
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
|
|
|
/// specialized.
|
|
|
|
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
|
|
|
/// specialized.
|
|
|
|
pub fn make_ndarray_ty(
|
|
|
|
unifier: &mut Unifier,
|
|
|
|
primitives: &PrimitiveStore,
|
|
|
|
dtype: Option<Type>,
|
|
|
|
ndims: Option<Type>,
|
|
|
|
) -> Type {
|
2024-03-26 19:14:56 +08:00
|
|
|
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
|
|
|
|
}
|
2024-02-27 13:39:05 +08:00
|
|
|
|
2024-03-26 19:14:56 +08:00
|
|
|
/// Substitutes type variables in `ndarray`.
|
|
|
|
///
|
|
|
|
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
|
|
|
/// specialized.
|
|
|
|
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
|
|
|
/// specialized.
|
|
|
|
pub fn subst_ndarray_tvars(
|
|
|
|
unifier: &mut Unifier,
|
|
|
|
ndarray: Type,
|
|
|
|
dtype: Option<Type>,
|
|
|
|
ndims: Option<Type>,
|
|
|
|
) -> Type {
|
2024-02-27 13:39:05 +08:00
|
|
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
|
|
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
|
|
|
};
|
|
|
|
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
|
|
|
|
|
2024-03-26 19:14:56 +08:00
|
|
|
if dtype.is_none() && ndims.is_none() {
|
|
|
|
return ndarray
|
|
|
|
}
|
|
|
|
|
2024-02-27 13:39:05 +08:00
|
|
|
let tvar_ids = params.iter()
|
|
|
|
.map(|(obj_id, _)| *obj_id)
|
|
|
|
.collect_vec();
|
|
|
|
debug_assert_eq!(tvar_ids.len(), 2);
|
|
|
|
|
2024-03-04 23:38:52 +08:00
|
|
|
let mut tvar_subst = VarMap::new();
|
2024-02-27 13:39:05 +08:00
|
|
|
if let Some(dtype) = dtype {
|
|
|
|
tvar_subst.insert(tvar_ids[0], dtype);
|
|
|
|
}
|
|
|
|
if let Some(ndims) = ndims {
|
|
|
|
tvar_subst.insert(tvar_ids[1], ndims);
|
|
|
|
}
|
|
|
|
|
|
|
|
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
|
|
|
}
|
|
|
|
|
2024-03-26 19:14:56 +08:00
|
|
|
fn unpack_ndarray_tvars(
|
2024-02-27 13:39:05 +08:00
|
|
|
unifier: &mut Unifier,
|
|
|
|
ndarray: Type,
|
2024-03-26 19:14:56 +08:00
|
|
|
) -> Vec<(u32, Type)> {
|
2024-02-27 13:39:05 +08:00
|
|
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
|
|
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
|
|
|
};
|
|
|
|
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
|
|
|
|
debug_assert_eq!(params.len(), 2);
|
|
|
|
|
|
|
|
params.iter()
|
|
|
|
.sorted_by_key(|(obj_id, _)| *obj_id)
|
2024-03-26 19:14:56 +08:00
|
|
|
.map(|(var_id, ty)| (*var_id, *ty))
|
|
|
|
.collect_vec()
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
|
|
|
|
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
|
|
|
|
/// respectively.
|
|
|
|
pub fn unpack_ndarray_var_ids(
|
|
|
|
unifier: &mut Unifier,
|
|
|
|
ndarray: Type,
|
|
|
|
) -> (u32, u32) {
|
|
|
|
unpack_ndarray_tvars(unifier, ndarray)
|
|
|
|
.into_iter()
|
|
|
|
.map(|v| v.0)
|
|
|
|
.collect_tuple()
|
|
|
|
.unwrap()
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
|
|
|
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
|
|
|
pub fn unpack_ndarray_var_tys(
|
|
|
|
unifier: &mut Unifier,
|
|
|
|
ndarray: Type,
|
|
|
|
) -> (Type, Type) {
|
|
|
|
unpack_ndarray_tvars(unifier, ndarray)
|
|
|
|
.into_iter()
|
|
|
|
.map(|v| v.1)
|
2024-02-27 13:39:05 +08:00
|
|
|
.collect_tuple()
|
|
|
|
.unwrap()
|
|
|
|
}
|