diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 4021a56..32c0f9d 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -5,7 +5,7 @@ use nac3core::{ toplevel::{ DefinitionId, helper::PRIMITIVE_DEF_IDS, - numpy::{make_ndarray_ty, unpack_ndarray_tvars}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelDef, }, typecheck::{ @@ -665,7 +665,7 @@ impl InnerResolver { } } (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (ty, ndims) = unpack_ndarray_tvars(unifier, extracted_ty); + let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; if len == 0 { assert!(matches!( diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index c528eab..c07185e 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -2,7 +2,7 @@ use crate::{ symbol_resolver::{StaticValue, SymbolResolver}, toplevel::{ helper::PRIMITIVE_DEF_IDS, - numpy::unpack_ndarray_tvars, + numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef, }, @@ -451,7 +451,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { let llvm_usize = generator.get_size_type(ctx); - let (dtype, _) = unpack_ndarray_tvars(unifier, ty); + let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let element_type = get_llvm_type( ctx, module, diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 4cc8407..bb05ef9 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -27,7 +27,7 @@ use crate::{ symbol_resolver::ValueEnum, toplevel::{ DefinitionId, - numpy::{make_ndarray_ty, unpack_ndarray_tvars}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, }, typecheck::typedef::{FunSignature, Type}, }; @@ -748,7 +748,7 @@ pub fn gen_ndarray_copy<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_tvars(&mut context.unifier, this_ty); + let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_arg = obj .as_ref() .unwrap() diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 61018a3..1e0ab2c 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -13,7 +13,7 @@ use crate::{ toplevel::{ DefinitionId, helper::PRIMITIVE_DEF_IDS, - numpy::unpack_ndarray_tvars, + numpy::unpack_ndarray_var_tys, TopLevelDef, }, typecheck::typedef::{FunSignature, Type, TypeEnum}, @@ -251,7 +251,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { TypeEnum::TList { ty } => *ty, TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - unpack_ndarray_tvars(&mut ctx.unifier, target.custom.unwrap()).0 + unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 } _ => unreachable!(), }; diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index d322519..aee0904 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -19,13 +19,30 @@ pub fn make_ndarray_ty( dtype: Option, ndims: Option, ) -> Type { - let ndarray = primitives.ndarray; + subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims) +} +/// Substitutes type variables in `ndarray`. +/// +/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not +/// specialized. +/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not +/// specialized. +pub fn subst_ndarray_tvars( + unifier: &mut Unifier, + ndarray: Type, + dtype: Option, + ndims: Option, +) -> Type { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); + if dtype.is_none() && ndims.is_none() { + return ndarray + } + let tvar_ids = params.iter() .map(|(obj_id, _)| *obj_id) .collect_vec(); @@ -42,12 +59,10 @@ pub fn make_ndarray_ty( unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) } -/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to -/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively. -pub fn unpack_ndarray_tvars( +fn unpack_ndarray_tvars( unifier: &mut Unifier, ndarray: Type, -) -> (Type, Type) { +) -> Vec<(u32, Type)> { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; @@ -56,7 +71,33 @@ pub fn unpack_ndarray_tvars( params.iter() .sorted_by_key(|(obj_id, _)| *obj_id) - .map(|(_, ty)| *ty) + .map(|(var_id, ty)| (*var_id, *ty)) + .collect_vec() +} + +/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds +/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` +/// respectively. +pub fn unpack_ndarray_var_ids( + unifier: &mut Unifier, + ndarray: Type, +) -> (u32, u32) { + unpack_ndarray_tvars(unifier, ndarray) + .into_iter() + .map(|v| v.0) + .collect_tuple() + .unwrap() +} + +/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to +/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively. +pub fn unpack_ndarray_var_tys( + unifier: &mut Unifier, + ndarray: Type, +) -> (Type, Type) { + unpack_ndarray_tvars(unifier, ndarray) + .into_iter() + .map(|v| v.1) .collect_tuple() .unwrap() } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index f1f2a5a..5fb89e0 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -9,7 +9,7 @@ use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ helper::PRIMITIVE_DEF_IDS, - numpy::{make_ndarray_ty, unpack_ndarray_tvars}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, }, }; @@ -1344,7 +1344,7 @@ impl<'a> Inferencer<'a> { let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); + let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) } @@ -1357,7 +1357,7 @@ impl<'a> Inferencer<'a> { ExprKind::Constant { value: ast::Constant::Int(val), .. } => { match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); + let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); self.infer_subscript_ndarray(value, ty, ndims) } _ => { @@ -1389,7 +1389,7 @@ impl<'a> Inferencer<'a> { Ok(ty) } TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); + let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); let valid_index_tys = [ self.primitives.int32,