use itertools::Itertools; use super::helper::PrimDef; use crate::typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap}, }; /// Creates a `ndarray` [`Type`] with the given type arguments. /// /// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not /// specialized. /// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not /// specialized. pub fn make_ndarray_ty( unifier: &mut Unifier, primitives: &PrimitiveStore, dtype: Option, ndims: Option, ) -> Type { subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims) } /// Substitutes type variables in `ndarray`. /// /// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not /// specialized. /// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not /// specialized. pub fn subst_ndarray_tvars( unifier: &mut Unifier, ndarray: Type, dtype: Option, ndims: Option, ) -> Type { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; debug_assert_eq!(*obj_id, PrimDef::NDArray.id()); if dtype.is_none() && ndims.is_none() { return ndarray; } let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec(); debug_assert_eq!(tvar_ids.len(), 2); let mut tvar_subst = VarMap::new(); if let Some(dtype) = dtype { tvar_subst.insert(tvar_ids[0], dtype); } if let Some(ndims) = ndims { tvar_subst.insert(tvar_ids[1], ndims); } unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) } fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; debug_assert_eq!(*obj_id, PrimDef::NDArray.id()); debug_assert_eq!(params.len(), 2); params .iter() .sorted_by_key(|(obj_id, _)| *obj_id) .map(|(var_id, ty)| (*var_id, *ty)) .collect_vec() } /// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds /// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` /// respectively. pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) { unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap() } /// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to /// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively. pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) { unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap() }