core: Implement helper for creation of generic ndarray

This commit is contained in:
David Mak 2023-12-21 14:43:19 +08:00
parent c395472094
commit afa7d9b100
1 changed files with 21 additions and 0 deletions

View File

@ -205,6 +205,27 @@ impl TypeEnum {
TypeEnum::TFunc { .. } => "TFunc", TypeEnum::TFunc { .. } => "TFunc",
} }
} }
/// Returns a [TypeEnum] representing a generic `ndarray` type.
///
/// * `dtype` - The datatype of the `ndarray`, or `None` if the datatype is generic.
/// * `ndims` - The number of dimensions of the `ndarray`, or `None` if the number of dimensions is generic.
#[must_use]
pub fn ndarray(
unifier: &mut Unifier,
dtype: Option<Type>,
ndims: Option<Type>,
primitives: &PrimitiveStore
) -> TypeEnum {
let dtype = dtype.unwrap_or_else(|| unifier.get_fresh_var(Some("T".into()), None).0);
let ndims = ndims
.unwrap_or_else(|| unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None).0);
TypeEnum::TNDArray {
ty: dtype,
ndims,
}
}
} }
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>; pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;