1
0
forked from M-Labs/nac3
nac3/nac3core/src/toplevel/numpy.rs

85 lines
3.0 KiB
Rust
Raw Normal View History

use itertools::Itertools;
use super::helper::PrimDef;
use crate::typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
};
2024-02-27 13:39:05 +08:00
/// Creates a `ndarray` [`Type`] with the given type arguments.
2024-06-12 14:45:03 +08:00
///
2024-02-27 13:39:05 +08:00
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
2024-08-21 11:10:52 +08:00
/// specialized.
2024-02-27 13:39:05 +08:00
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
2024-08-21 11:10:52 +08:00
/// specialized.
2024-02-27 13:39:05 +08:00
pub fn make_ndarray_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
2024-02-27 13:39:05 +08:00
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
2024-08-21 11:10:52 +08:00
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
2024-08-21 11:10:52 +08:00
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
2024-02-27 13:39:05 +08:00
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
2024-02-27 13:39:05 +08:00
if dtype.is_none() && ndims.is_none() {
2024-06-12 14:45:03 +08:00
return ndarray;
}
2024-06-12 14:45:03 +08:00
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
2024-02-27 13:39:05 +08:00
debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new();
2024-02-27 13:39:05 +08:00
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
if let Some(ndims) = ndims {
tvar_subst.insert(tvar_ids[1], ndims);
}
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
2024-02-27 13:39:05 +08:00
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
2024-02-27 13:39:05 +08:00
debug_assert_eq!(params.len(), 2);
2024-06-12 14:45:03 +08:00
params
.iter()
2024-02-27 13:39:05 +08:00
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
2024-06-12 14:45:03 +08:00
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
2024-06-12 14:45:03 +08:00
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
2024-06-12 14:45:03 +08:00
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
2024-02-27 13:39:05 +08:00
}