forked from M-Labs/nac3
core/toplevel/numpy: Split ndarray type var utilities
This commit is contained in:
parent
87bc34f7ec
commit
3a6c53d760
@ -5,7 +5,7 @@ use nac3core::{
|
|||||||
toplevel::{
|
toplevel::{
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
helper::PRIMITIVE_DEF_IDS,
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
TopLevelDef,
|
TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
@ -665,7 +665,7 @@ impl InnerResolver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
let (ty, ndims) = unpack_ndarray_tvars(unifier, extracted_ty);
|
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
|
||||||
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
|
@ -2,7 +2,7 @@ use crate::{
|
|||||||
symbol_resolver::{StaticValue, SymbolResolver},
|
symbol_resolver::{StaticValue, SymbolResolver},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PRIMITIVE_DEF_IDS,
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
numpy::unpack_ndarray_tvars,
|
numpy::unpack_ndarray_var_tys,
|
||||||
TopLevelContext,
|
TopLevelContext,
|
||||||
TopLevelDef,
|
TopLevelDef,
|
||||||
},
|
},
|
||||||
@ -451,7 +451,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
|
|
||||||
TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
let llvm_usize = generator.get_size_type(ctx);
|
let llvm_usize = generator.get_size_type(ctx);
|
||||||
let (dtype, _) = unpack_ndarray_tvars(unifier, ty);
|
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
|
||||||
let element_type = get_llvm_type(
|
let element_type = get_llvm_type(
|
||||||
ctx,
|
ctx,
|
||||||
module,
|
module,
|
||||||
|
@ -27,7 +27,7 @@ use crate::{
|
|||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
},
|
},
|
||||||
typecheck::typedef::{FunSignature, Type},
|
typecheck::typedef::{FunSignature, Type},
|
||||||
};
|
};
|
||||||
@ -748,7 +748,7 @@ pub fn gen_ndarray_copy<'ctx>(
|
|||||||
let llvm_usize = generator.get_size_type(context.ctx);
|
let llvm_usize = generator.get_size_type(context.ctx);
|
||||||
|
|
||||||
let this_ty = obj.as_ref().unwrap().0;
|
let this_ty = obj.as_ref().unwrap().0;
|
||||||
let (this_elem_ty, _) = unpack_ndarray_tvars(&mut context.unifier, this_ty);
|
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
|
||||||
let this_arg = obj
|
let this_arg = obj
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -13,7 +13,7 @@ use crate::{
|
|||||||
toplevel::{
|
toplevel::{
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
helper::PRIMITIVE_DEF_IDS,
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
numpy::unpack_ndarray_tvars,
|
numpy::unpack_ndarray_var_tys,
|
||||||
TopLevelDef,
|
TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||||
@ -251,7 +251,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
|||||||
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
||||||
TypeEnum::TList { ty } => *ty,
|
TypeEnum::TList { ty } => *ty,
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
unpack_ndarray_tvars(&mut ctx.unifier, target.custom.unwrap()).0
|
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
@ -19,13 +19,30 @@ pub fn make_ndarray_ty(
|
|||||||
dtype: Option<Type>,
|
dtype: Option<Type>,
|
||||||
ndims: Option<Type>,
|
ndims: Option<Type>,
|
||||||
) -> Type {
|
) -> Type {
|
||||||
let ndarray = primitives.ndarray;
|
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Substitutes type variables in `ndarray`.
|
||||||
|
///
|
||||||
|
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
||||||
|
/// specialized.
|
||||||
|
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
||||||
|
/// specialized.
|
||||||
|
pub fn subst_ndarray_tvars(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
ndarray: Type,
|
||||||
|
dtype: Option<Type>,
|
||||||
|
ndims: Option<Type>,
|
||||||
|
) -> Type {
|
||||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||||
};
|
};
|
||||||
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
|
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
|
||||||
|
if dtype.is_none() && ndims.is_none() {
|
||||||
|
return ndarray
|
||||||
|
}
|
||||||
|
|
||||||
let tvar_ids = params.iter()
|
let tvar_ids = params.iter()
|
||||||
.map(|(obj_id, _)| *obj_id)
|
.map(|(obj_id, _)| *obj_id)
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
@ -42,12 +59,10 @@ pub fn make_ndarray_ty(
|
|||||||
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
fn unpack_ndarray_tvars(
|
||||||
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
|
||||||
pub fn unpack_ndarray_tvars(
|
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
ndarray: Type,
|
ndarray: Type,
|
||||||
) -> (Type, Type) {
|
) -> Vec<(u32, Type)> {
|
||||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||||
};
|
};
|
||||||
@ -56,7 +71,33 @@ pub fn unpack_ndarray_tvars(
|
|||||||
|
|
||||||
params.iter()
|
params.iter()
|
||||||
.sorted_by_key(|(obj_id, _)| *obj_id)
|
.sorted_by_key(|(obj_id, _)| *obj_id)
|
||||||
.map(|(_, ty)| *ty)
|
.map(|(var_id, ty)| (*var_id, *ty))
|
||||||
|
.collect_vec()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
|
||||||
|
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
|
||||||
|
/// respectively.
|
||||||
|
pub fn unpack_ndarray_var_ids(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
ndarray: Type,
|
||||||
|
) -> (u32, u32) {
|
||||||
|
unpack_ndarray_tvars(unifier, ndarray)
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| v.0)
|
||||||
|
.collect_tuple()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
||||||
|
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
||||||
|
pub fn unpack_ndarray_var_tys(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
ndarray: Type,
|
||||||
|
) -> (Type, Type) {
|
||||||
|
unpack_ndarray_tvars(unifier, ndarray)
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| v.1)
|
||||||
.collect_tuple()
|
.collect_tuple()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ use crate::{
|
|||||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PRIMITIVE_DEF_IDS,
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
TopLevelContext,
|
TopLevelContext,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@ -1344,7 +1344,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
|
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||||
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
|
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
|
||||||
}
|
}
|
||||||
@ -1357,7 +1357,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||||
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
|
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||||
self.infer_subscript_ndarray(value, ty, ndims)
|
self.infer_subscript_ndarray(value, ty, ndims)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
@ -1389,7 +1389,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
Ok(ty)
|
Ok(ty)
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
|
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||||
|
|
||||||
let valid_index_tys = [
|
let valid_index_tys = [
|
||||||
self.primitives.int32,
|
self.primitives.int32,
|
||||||
|
Loading…
Reference in New Issue
Block a user