forked from M-Labs/nac3
1
0
Fork 0

core/toplevel/numpy: Split ndarray type var utilities

This commit is contained in:
David Mak 2024-03-26 19:14:56 +08:00
parent 87bc34f7ec
commit 3a6c53d760
6 changed files with 59 additions and 18 deletions

View File

@ -5,7 +5,7 @@ use nac3core::{
toplevel::{ toplevel::{
DefinitionId, DefinitionId,
helper::PRIMITIVE_DEF_IDS, helper::PRIMITIVE_DEF_IDS,
numpy::{make_ndarray_ty, unpack_ndarray_tvars}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelDef, TopLevelDef,
}, },
typecheck::{ typecheck::{
@ -665,7 +665,7 @@ impl InnerResolver {
} }
} }
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (ty, ndims) = unpack_ndarray_tvars(unifier, extracted_ty); let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
if len == 0 { if len == 0 {
assert!(matches!( assert!(matches!(

View File

@ -2,7 +2,7 @@ use crate::{
symbol_resolver::{StaticValue, SymbolResolver}, symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{ toplevel::{
helper::PRIMITIVE_DEF_IDS, helper::PRIMITIVE_DEF_IDS,
numpy::unpack_ndarray_tvars, numpy::unpack_ndarray_var_tys,
TopLevelContext, TopLevelContext,
TopLevelDef, TopLevelDef,
}, },
@ -451,7 +451,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let llvm_usize = generator.get_size_type(ctx); let llvm_usize = generator.get_size_type(ctx);
let (dtype, _) = unpack_ndarray_tvars(unifier, ty); let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let element_type = get_llvm_type( let element_type = get_llvm_type(
ctx, ctx,
module, module,

View File

@ -27,7 +27,7 @@ use crate::{
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{ toplevel::{
DefinitionId, DefinitionId,
numpy::{make_ndarray_ty, unpack_ndarray_tvars}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
}, },
typecheck::typedef::{FunSignature, Type}, typecheck::typedef::{FunSignature, Type},
}; };
@ -748,7 +748,7 @@ pub fn gen_ndarray_copy<'ctx>(
let llvm_usize = generator.get_size_type(context.ctx); let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_tvars(&mut context.unifier, this_ty); let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
let this_arg = obj let this_arg = obj
.as_ref() .as_ref()
.unwrap() .unwrap()

View File

@ -13,7 +13,7 @@ use crate::{
toplevel::{ toplevel::{
DefinitionId, DefinitionId,
helper::PRIMITIVE_DEF_IDS, helper::PRIMITIVE_DEF_IDS,
numpy::unpack_ndarray_tvars, numpy::unpack_ndarray_var_tys,
TopLevelDef, TopLevelDef,
}, },
typecheck::typedef::{FunSignature, Type, TypeEnum}, typecheck::typedef::{FunSignature, Type, TypeEnum},
@ -251,7 +251,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TList { ty } => *ty, TypeEnum::TList { ty } => *ty,
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
unpack_ndarray_tvars(&mut ctx.unifier, target.custom.unwrap()).0 unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
} }
_ => unreachable!(), _ => unreachable!(),
}; };

View File

@ -19,13 +19,30 @@ pub fn make_ndarray_ty(
dtype: Option<Type>, dtype: Option<Type>,
ndims: Option<Type>, ndims: Option<Type>,
) -> Type { ) -> Type {
let ndarray = primitives.ndarray; subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
}; };
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
if dtype.is_none() && ndims.is_none() {
return ndarray
}
let tvar_ids = params.iter() let tvar_ids = params.iter()
.map(|(obj_id, _)| *obj_id) .map(|(obj_id, _)| *obj_id)
.collect_vec(); .collect_vec();
@ -42,12 +59,10 @@ pub fn make_ndarray_ty(
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
} }
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to fn unpack_ndarray_tvars(
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_tvars(
unifier: &mut Unifier, unifier: &mut Unifier,
ndarray: Type, ndarray: Type,
) -> (Type, Type) { ) -> Vec<(u32, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
}; };
@ -56,7 +71,33 @@ pub fn unpack_ndarray_tvars(
params.iter() params.iter()
.sorted_by_key(|(obj_id, _)| *obj_id) .sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(_, ty)| *ty) .map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(
unifier: &mut Unifier,
ndarray: Type,
) -> (u32, u32) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.0)
.collect_tuple()
.unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(
unifier: &mut Unifier,
ndarray: Type,
) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.1)
.collect_tuple() .collect_tuple()
.unwrap() .unwrap()
} }

View File

@ -9,7 +9,7 @@ use crate::{
symbol_resolver::{SymbolResolver, SymbolValue}, symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{ toplevel::{
helper::PRIMITIVE_DEF_IDS, helper::PRIMITIVE_DEF_IDS,
numpy::{make_ndarray_ty, unpack_ndarray_tvars}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext, TopLevelContext,
}, },
}; };
@ -1344,7 +1344,7 @@ impl<'a> Inferencer<'a> {
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
} }
@ -1357,7 +1357,7 @@ impl<'a> Inferencer<'a> {
ExprKind::Constant { value: ast::Constant::Int(val), .. } => { ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
match &*self.unifier.get_ty(value.custom.unwrap()) { match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
self.infer_subscript_ndarray(value, ty, ndims) self.infer_subscript_ndarray(value, ty, ndims)
} }
_ => { _ => {
@ -1389,7 +1389,7 @@ impl<'a> Inferencer<'a> {
Ok(ty) Ok(ty)
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let valid_index_tys = [ let valid_index_tys = [
self.primitives.int32, self.primitives.int32,