2024-02-27 13:39:05 +08:00
|
|
|
use itertools::Itertools;
|
2023-11-17 17:30:27 +08:00
|
|
|
use crate::{
|
2024-03-11 14:47:01 +08:00
|
|
|
toplevel::helper::PRIMITIVE_DEF_IDS,
|
2024-02-27 13:39:05 +08:00
|
|
|
typecheck::{
|
|
|
|
type_inferencer::PrimitiveStore,
|
2024-03-11 14:47:01 +08:00
|
|
|
typedef::{Type, TypeEnum, Unifier, VarMap},
|
2024-02-27 13:39:05 +08:00
|
|
|
},
|
2023-11-17 17:30:27 +08:00
|
|
|
};
|
|
|
|
|
2024-02-27 13:39:05 +08:00
|
|
|
/// Creates a `ndarray` [`Type`] with the given type arguments.
|
|
|
|
///
|
|
|
|
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
|
|
|
/// specialized.
|
|
|
|
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
|
|
|
/// specialized.
|
|
|
|
pub fn make_ndarray_ty(
|
|
|
|
unifier: &mut Unifier,
|
|
|
|
primitives: &PrimitiveStore,
|
|
|
|
dtype: Option<Type>,
|
|
|
|
ndims: Option<Type>,
|
|
|
|
) -> Type {
|
|
|
|
let ndarray = primitives.ndarray;
|
|
|
|
|
|
|
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
|
|
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
|
|
|
};
|
|
|
|
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
|
|
|
|
|
|
|
|
let tvar_ids = params.iter()
|
|
|
|
.map(|(obj_id, _)| *obj_id)
|
|
|
|
.collect_vec();
|
|
|
|
debug_assert_eq!(tvar_ids.len(), 2);
|
|
|
|
|
2024-03-04 23:38:52 +08:00
|
|
|
let mut tvar_subst = VarMap::new();
|
2024-02-27 13:39:05 +08:00
|
|
|
if let Some(dtype) = dtype {
|
|
|
|
tvar_subst.insert(tvar_ids[0], dtype);
|
|
|
|
}
|
|
|
|
if let Some(ndims) = ndims {
|
|
|
|
tvar_subst.insert(tvar_ids[1], ndims);
|
|
|
|
}
|
|
|
|
|
|
|
|
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
|
|
|
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
|
|
|
pub fn unpack_ndarray_tvars(
|
|
|
|
unifier: &mut Unifier,
|
|
|
|
ndarray: Type,
|
|
|
|
) -> (Type, Type) {
|
|
|
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
|
|
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
|
|
|
};
|
|
|
|
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
|
|
|
|
debug_assert_eq!(params.len(), 2);
|
|
|
|
|
|
|
|
params.iter()
|
|
|
|
.sorted_by_key(|(obj_id, _)| *obj_id)
|
|
|
|
.map(|(_, ty)| *ty)
|
|
|
|
.collect_tuple()
|
|
|
|
.unwrap()
|
|
|
|
}
|