forked from M-Labs/nac3
core/toplevel/numpy: Split ndarray type var utilities
This commit is contained in:
parent
87bc34f7ec
commit
3a6c53d760
|
@ -5,7 +5,7 @@ use nac3core::{
|
|||
toplevel::{
|
||||
DefinitionId,
|
||||
helper::PRIMITIVE_DEF_IDS,
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
TopLevelDef,
|
||||
},
|
||||
typecheck::{
|
||||
|
@ -665,7 +665,7 @@ impl InnerResolver {
|
|||
}
|
||||
}
|
||||
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||
let (ty, ndims) = unpack_ndarray_tvars(unifier, extracted_ty);
|
||||
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
|
||||
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
||||
if len == 0 {
|
||||
assert!(matches!(
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::{
|
|||
symbol_resolver::{StaticValue, SymbolResolver},
|
||||
toplevel::{
|
||||
helper::PRIMITIVE_DEF_IDS,
|
||||
numpy::unpack_ndarray_tvars,
|
||||
numpy::unpack_ndarray_var_tys,
|
||||
TopLevelContext,
|
||||
TopLevelDef,
|
||||
},
|
||||
|
@ -451,7 +451,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
|||
|
||||
TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||
let llvm_usize = generator.get_size_type(ctx);
|
||||
let (dtype, _) = unpack_ndarray_tvars(unifier, ty);
|
||||
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
|
||||
let element_type = get_llvm_type(
|
||||
ctx,
|
||||
module,
|
||||
|
|
|
@ -27,7 +27,7 @@ use crate::{
|
|||
symbol_resolver::ValueEnum,
|
||||
toplevel::{
|
||||
DefinitionId,
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
},
|
||||
typecheck::typedef::{FunSignature, Type},
|
||||
};
|
||||
|
@ -748,7 +748,7 @@ pub fn gen_ndarray_copy<'ctx>(
|
|||
let llvm_usize = generator.get_size_type(context.ctx);
|
||||
|
||||
let this_ty = obj.as_ref().unwrap().0;
|
||||
let (this_elem_ty, _) = unpack_ndarray_tvars(&mut context.unifier, this_ty);
|
||||
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
|
||||
let this_arg = obj
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
|
|
|
@ -13,7 +13,7 @@ use crate::{
|
|||
toplevel::{
|
||||
DefinitionId,
|
||||
helper::PRIMITIVE_DEF_IDS,
|
||||
numpy::unpack_ndarray_tvars,
|
||||
numpy::unpack_ndarray_var_tys,
|
||||
TopLevelDef,
|
||||
},
|
||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||
|
@ -251,7 +251,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
|||
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
||||
TypeEnum::TList { ty } => *ty,
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||
unpack_ndarray_tvars(&mut ctx.unifier, target.custom.unwrap()).0
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
|
|
@ -19,13 +19,30 @@ pub fn make_ndarray_ty(
|
|||
dtype: Option<Type>,
|
||||
ndims: Option<Type>,
|
||||
) -> Type {
|
||||
let ndarray = primitives.ndarray;
|
||||
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
|
||||
}
|
||||
|
||||
/// Substitutes type variables in `ndarray`.
|
||||
///
|
||||
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
pub fn subst_ndarray_tvars(
|
||||
unifier: &mut Unifier,
|
||||
ndarray: Type,
|
||||
dtype: Option<Type>,
|
||||
ndims: Option<Type>,
|
||||
) -> Type {
|
||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||
};
|
||||
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
|
||||
|
||||
if dtype.is_none() && ndims.is_none() {
|
||||
return ndarray
|
||||
}
|
||||
|
||||
let tvar_ids = params.iter()
|
||||
.map(|(obj_id, _)| *obj_id)
|
||||
.collect_vec();
|
||||
|
@ -42,12 +59,10 @@ pub fn make_ndarray_ty(
|
|||
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
||||
}
|
||||
|
||||
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
||||
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
||||
pub fn unpack_ndarray_tvars(
|
||||
fn unpack_ndarray_tvars(
|
||||
unifier: &mut Unifier,
|
||||
ndarray: Type,
|
||||
) -> (Type, Type) {
|
||||
) -> Vec<(u32, Type)> {
|
||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||
};
|
||||
|
@ -56,7 +71,33 @@ pub fn unpack_ndarray_tvars(
|
|||
|
||||
params.iter()
|
||||
.sorted_by_key(|(obj_id, _)| *obj_id)
|
||||
.map(|(_, ty)| *ty)
|
||||
.map(|(var_id, ty)| (*var_id, *ty))
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
|
||||
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
|
||||
/// respectively.
|
||||
pub fn unpack_ndarray_var_ids(
|
||||
unifier: &mut Unifier,
|
||||
ndarray: Type,
|
||||
) -> (u32, u32) {
|
||||
unpack_ndarray_tvars(unifier, ndarray)
|
||||
.into_iter()
|
||||
.map(|v| v.0)
|
||||
.collect_tuple()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
||||
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
||||
pub fn unpack_ndarray_var_tys(
|
||||
unifier: &mut Unifier,
|
||||
ndarray: Type,
|
||||
) -> (Type, Type) {
|
||||
unpack_ndarray_tvars(unifier, ndarray)
|
||||
.into_iter()
|
||||
.map(|v| v.1)
|
||||
.collect_tuple()
|
||||
.unwrap()
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ use crate::{
|
|||
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||
toplevel::{
|
||||
helper::PRIMITIVE_DEF_IDS,
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
TopLevelContext,
|
||||
},
|
||||
};
|
||||
|
@ -1344,7 +1344,7 @@ impl<'a> Inferencer<'a> {
|
|||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
|
||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
|
||||
}
|
||||
|
@ -1357,7 +1357,7 @@ impl<'a> Inferencer<'a> {
|
|||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
|
||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||
self.infer_subscript_ndarray(value, ty, ndims)
|
||||
}
|
||||
_ => {
|
||||
|
@ -1389,7 +1389,7 @@ impl<'a> Inferencer<'a> {
|
|||
Ok(ty)
|
||||
}
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
|
||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||
|
||||
let valid_index_tys = [
|
||||
self.primitives.int32,
|
||||
|
|
Loading…
Reference in New Issue