From 4a81ca08d263ab86ff7e61bdf2fb75845dadda54 Mon Sep 17 00:00:00 2001 From: lyken Date: Mon, 17 Jun 2024 13:34:57 +0800 Subject: [PATCH] core: move top level def type vars into `PrimitiveStore` --- nac3artiq/src/symbol_resolver.rs | 18 +- nac3core/src/codegen/builtin_fns.rs | 183 ++++++++++-------- nac3core/src/codegen/expr.rs | 47 +++-- nac3core/src/codegen/mod.rs | 27 ++- nac3core/src/codegen/numpy.rs | 6 +- nac3core/src/codegen/stmt.rs | 5 +- nac3core/src/toplevel/builtins.rs | 46 +++-- nac3core/src/toplevel/helper.rs | 50 +++-- nac3core/src/toplevel/numpy.rs | 38 ++-- nac3core/src/typecheck/magic_methods.rs | 28 +-- nac3core/src/typecheck/type_inferencer/mod.rs | 63 ++++-- .../src/typecheck/type_inferencer/test.rs | 21 +- 12 files changed, 316 insertions(+), 216 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 0b9ede9..c9bf671 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -4,7 +4,7 @@ use nac3core::{ symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, toplevel::{ helper::PrimDef, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, + numpy::{make_ndarray_ty, unpack_ndarray_params}, DefinitionId, TopLevelDef, }, typecheck::{ @@ -665,11 +665,11 @@ impl InnerResolver { } } (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => { - let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); + let params = unpack_ndarray_params(unifier, primitives, extracted_ty); let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; if len == 0 { assert!(matches!( - &*unifier.get_ty(ty), + &*unifier.get_ty(params.dtype), TypeEnum::TVar { fields: None, range, .. } if range.is_empty() )); @@ -678,10 +678,14 @@ impl InnerResolver { let actual_ty = self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; match actual_ty { - Ok(t) => match unifier.unify(ty, t) { + Ok(t) => match unifier.unify(params.dtype, t) { Ok(()) => { - let ndarray_ty = - make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims)); + let ndarray_ty = make_ndarray_ty( + unifier, + primitives, + Some(params.dtype), + Some(params.ndims), + ); Ok(Ok(ndarray_ty)) } @@ -984,7 +988,7 @@ impl InnerResolver { TypeEnum::TObj { obj_id, params, .. } if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => { - *params.iter().next().unwrap().1 + *params.get(&ctx.primitives.option_type_tvar.id).unwrap() } _ => unreachable!("must be option type"), }; diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 0e4b75f..27bd7bc 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -8,7 +8,7 @@ use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::toplevel::helper::PrimDef; -use crate::toplevel::numpy::unpack_ndarray_var_tys; +use crate::toplevel::numpy::unpack_ndarray_params; use crate::typecheck::typedef::Type; /// Shorthand for [`unreachable!()`] when a type of argument is not supported. @@ -66,7 +66,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -128,7 +128,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -206,7 +206,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -273,7 +273,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -339,7 +339,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -385,7 +385,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -425,7 +425,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -490,7 +490,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -544,7 +544,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -594,7 +594,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -692,7 +692,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, a_ty).dtype; let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); @@ -792,16 +792,17 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - + let ndarray_dtype1 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype; + let ndarray_dtype2 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype } else { unreachable!() }; @@ -908,7 +909,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, a_ty).dtype; let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); @@ -1008,16 +1009,18 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype; + let ndarray_dtype2 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype } else { unreachable!() }; @@ -1088,7 +1091,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1128,7 +1131,7 @@ pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1172,7 +1175,7 @@ pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1216,7 +1219,7 @@ pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1256,7 +1259,7 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1296,7 +1299,7 @@ pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1336,7 +1339,7 @@ pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1376,7 +1379,7 @@ pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1416,7 +1419,7 @@ pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1456,7 +1459,7 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1496,7 +1499,7 @@ pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1536,7 +1539,7 @@ pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1576,7 +1579,7 @@ pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1616,7 +1619,7 @@ pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1656,7 +1659,7 @@ pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1696,7 +1699,7 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1736,7 +1739,7 @@ pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1776,7 +1779,7 @@ pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1816,7 +1819,7 @@ pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1856,7 +1859,7 @@ pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1896,7 +1899,7 @@ pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1936,7 +1939,7 @@ pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -1976,7 +1979,7 @@ pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -2016,7 +2019,7 @@ pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -2056,7 +2059,7 @@ pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -2096,7 +2099,7 @@ pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(z) if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, z_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -2136,7 +2139,7 @@ pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -2176,7 +2179,7 @@ pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(z) if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, z_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -2216,7 +2219,7 @@ pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -2256,7 +2259,7 @@ pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -2296,7 +2299,7 @@ pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); + let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -2345,16 +2348,18 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype; + let ndarray_dtype2 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype } else { unreachable!() }; @@ -2412,16 +2417,18 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype; + let ndarray_dtype2 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype } else { unreachable!() }; @@ -2479,16 +2486,18 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype; + let ndarray_dtype2 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype } else { unreachable!() }; @@ -2546,16 +2555,18 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype; + let ndarray_dtype2 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype } else { unreachable!() }; @@ -2612,12 +2623,18 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( let is_ndarray2 = x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let dtype = - if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; + let dtype = if is_ndarray1 { + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype + } else { + x1_ty + }; let x1_scalar_ty = dtype; - let x2_scalar_ty = - if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; + let x2_scalar_ty = if is_ndarray2 { + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype + } else { + x2_ty + }; numpy::ndarray_elementwise_binop_impl( generator, @@ -2669,16 +2686,18 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype; + let ndarray_dtype2 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype } else { unreachable!() }; @@ -2736,16 +2755,18 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype; + let ndarray_dtype2 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype } else { unreachable!() }; diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index d507548..25f64a5 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -17,7 +17,7 @@ use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{ helper::PrimDef, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, + numpy::{make_ndarray_ty, unpack_ndarray_params}, DefinitionId, TopLevelDef, }, typecheck::{ @@ -150,7 +150,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { TypeEnum::TObj { obj_id, params, .. } if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() => { - *params.iter().next().unwrap().1 + *params.get(&self.primitives.option_type_tvar.id).unwrap() } _ => unreachable!("must be option type"), }; @@ -166,7 +166,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { TypeEnum::TObj { obj_id, params, .. } if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() => { - *params.iter().next().unwrap().1 + *params.get(&self.primitives.option_type_tvar.id).unwrap() } _ => unreachable!("must be option type"), }; @@ -188,6 +188,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { &self.module, generator, &mut self.unifier, + &self.primitives, self.top_level, &mut self.type_cache, ty, @@ -205,6 +206,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { &self.module, generator, &mut self.unifier, + &self.primitives, self.top_level, &mut self.type_cache, &self.primitives, @@ -1190,8 +1192,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); + let ndarray_dtype1 = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty1).dtype; + let ndarray_dtype2 = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty2).dtype; assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); @@ -1240,8 +1242,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Ok(Some(res.as_base_value().into())) } else { - let (ndarray_dtype, _) = - unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); + let ndarray_dtype = unpack_ndarray_params( + &ctx.unifier, + &ctx.primitives, + if is_ndarray1 { ty1 } else { ty2 }, + ) + .dtype; let ndarray_val = NDArrayValue::from_ptr_val( if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), llvm_usize, @@ -1427,7 +1433,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( } } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); - let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + let ndarray_dtype = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty).dtype; let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); @@ -1511,8 +1517,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); return if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); + let ndarray_dtype1 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, left_ty).dtype; + let ndarray_dtype2 = + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, right_ty).dtype; assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); @@ -1546,10 +1554,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( Ok(Some(res.as_base_value().into())) } else { - let (ndarray_dtype, _) = unpack_ndarray_var_tys( - &mut ctx.unifier, + let ndarray_dtype = unpack_ndarray_params( + &ctx.unifier, + &ctx.primitives, if is_ndarray1 { left_ty } else { right_ty }, - ); + ) + .dtype; let res = numpy::ndarray_elementwise_binop_impl( generator, ctx, @@ -2014,10 +2024,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( (TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. }) if *obj_id == *opt_id => { - ctx.get_llvm_type(generator, *params.iter().next().unwrap().1) - .ptr_type(AddressSpace::default()) - .const_null() - .into() + ctx.get_llvm_type( + generator, + *params.get(&ctx.primitives.option_type_tvar.id).unwrap(), + ) + .ptr_type(AddressSpace::default()) + .const_null() + .into() } _ => unreachable!("must be option type"), } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 517c5c9..8f839a9 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,7 +1,7 @@ use crate::{ codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, symbol_resolver::{StaticValue, SymbolResolver}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, + toplevel::{helper::PrimDef, numpy::unpack_ndarray_params, TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, @@ -423,6 +423,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( module: &Module<'ctx>, generator: &mut G, unifier: &mut Unifier, + store: &PrimitiveStore, top_level: &TopLevelContext, type_cache: &mut HashMap>, ty: Type, @@ -443,18 +444,20 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( module, generator, unifier, + store, top_level, type_cache, - *params.iter().next().unwrap().1, + *params.get(&store.option_type_tvar.id).unwrap(), ) .ptr_type(AddressSpace::default()) .into() } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); + let dtype = unpack_ndarray_params(unifier, store, ty).dtype; let element_type = get_llvm_type( - ctx, module, generator, unifier, top_level, type_cache, dtype, + ctx, module, generator, unifier, store, top_level, type_cache, + dtype, ); NDArrayType::new(generator, ctx, element_type).as_base_type().into() @@ -490,6 +493,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( module, generator, unifier, + store, top_level, type_cache, fields[&f.0].0, @@ -506,14 +510,17 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( let fields = ty .iter() .map(|ty| { - get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) + get_llvm_type( + ctx, module, generator, unifier, store, top_level, type_cache, *ty, + ) }) .collect_vec(); ctx.struct_type(&fields, false).into() } TList { ty } => { - let element_type = - get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty); + let element_type = get_llvm_type( + ctx, module, generator, unifier, store, top_level, type_cache, *ty, + ); ListType::new(generator, ctx, element_type).as_base_type().into() } @@ -540,6 +547,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>( module: &Module<'ctx>, generator: &mut G, unifier: &mut Unifier, + store: &PrimitiveStore, top_level: &TopLevelContext, type_cache: &mut HashMap>, primitives: &PrimitiveStore, @@ -550,7 +558,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>( return if unifier.unioned(ty, primitives.bool) { ctx.bool_type().into() } else { - get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty) + get_llvm_type(ctx, module, generator, unifier, store, top_level, type_cache, ty) }; } @@ -699,6 +707,7 @@ pub fn gen_func_impl< &module, generator, &mut unifier, + &primitives, top_level_ctx.as_ref(), &mut type_cache, &primitives, @@ -715,6 +724,7 @@ pub fn gen_func_impl< &module, generator, &mut unifier, + &primitives, top_level_ctx.as_ref(), &mut type_cache, &primitives, @@ -767,6 +777,7 @@ pub fn gen_func_impl< &module, generator, &mut unifier, + &primitives, top_level_ctx.as_ref(), &mut type_cache, arg.ty, diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 7f19f4e..0ee1275 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -19,7 +19,7 @@ use crate::{ symbol_resolver::ValueEnum, toplevel::{ helper::PrimDef, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, + numpy::{make_ndarray_ty, unpack_ndarray_params}, DefinitionId, }, typecheck::typedef::{FunSignature, Type, TypeEnum}, @@ -1776,7 +1776,7 @@ pub fn gen_ndarray_array<'ctx>( let obj_ty = fun.0.args[0].ty; let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 + unpack_ndarray_params(&context.unifier, &context.primitives, obj_ty).dtype } TypeEnum::TList { ty } => { @@ -1916,7 +1916,7 @@ pub fn gen_ndarray_copy<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); + let this_elem_ty = unpack_ndarray_params(&context.unifier, &context.primitives, this_ty).dtype; let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index a670f11..005e262 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -10,7 +10,7 @@ use crate::{ expr::gen_binop_expr, gen_in_range_check, }, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, + toplevel::{helper::PrimDef, numpy::unpack_ndarray_params, DefinitionId, TopLevelDef}, typecheck::typedef::{FunSignature, Type, TypeEnum}, }; use inkwell::{ @@ -245,7 +245,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { TypeEnum::TList { ty } => *ty, TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 + unpack_ndarray_params(&ctx.unifier, &ctx.primitives, target.custom.unwrap()) + .dtype } _ => unreachable!(), }; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 6413d99..e085b81 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -25,7 +25,7 @@ use crate::{ }, symbol_resolver::SymbolValue, toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, - typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, + typecheck::typedef::{into_var_map, TypeVar, VarMap}, }; use super::*; @@ -301,10 +301,7 @@ struct BuiltinBuilder<'a> { is_some_ty: (Type, bool), unwrap_ty: (Type, bool), - option_tvar: TypeVar, - ndarray_dtype_tvar: TypeVar, - ndarray_ndims_tvar: TypeVar, ndarray_copy_ty: (Type, bool), ndarray_fill_ty: (Type, bool), @@ -339,24 +336,19 @@ impl<'a> BuiltinBuilder<'a> { } = *primitives; // Option-related - let (is_some_ty, unwrap_ty, option_tvar) = - if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() { + let (is_some_ty, unwrap_ty) = + if let TypeEnum::TObj { fields, .. } = unifier.get_ty(option).as_ref() { ( *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), - iter_type_vars(params).next().unwrap(), ) } else { unreachable!() }; - let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } = - &*unifier.get_ty(ndarray) - else { + let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray) else { unreachable!() }; - let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap(); - let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap(); let ndarray_copy_ty = *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); let ndarray_fill_ty = @@ -398,10 +390,7 @@ impl<'a> BuiltinBuilder<'a> { is_some_ty, unwrap_ty, - option_tvar, - ndarray_dtype_tvar, - ndarray_ndims_tvar, ndarray_copy_ty, ndarray_fill_ty, @@ -622,7 +611,7 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::Option => TopLevelDef::Class { name: prim.name().into(), object_id: prim.id(), - type_vars: vec![self.option_tvar.ty], + type_vars: vec![self.primitives.option_type_tvar.ty], fields: vec![], methods: vec![ Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0), @@ -642,7 +631,7 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), simple_name: prim.simple_name().into(), signature: self.unwrap_ty.0, - var_id: vec![self.option_tvar.id], + var_id: vec![self.primitives.option_type_tvar.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -656,7 +645,7 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().to_string(), simple_name: prim.simple_name().into(), signature: self.is_some_ty.0, - var_id: vec![self.option_tvar.id], + var_id: vec![self.primitives.option_type_tvar.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -693,13 +682,13 @@ impl<'a> BuiltinBuilder<'a> { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "n".into(), - ty: self.option_tvar.ty, + ty: self.primitives.option_type_tvar.ty, default_value: None, }], ret: self.primitives.option, - vars: into_var_map([self.option_tvar]), + vars: into_var_map([self.primitives.option_type_tvar]), })), - var_id: vec![self.option_tvar.id], + var_id: vec![self.primitives.option_type_tvar.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -735,7 +724,10 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::NDArray => TopLevelDef::Class { name: prim.name().into(), object_id: prim.id(), - type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty], + type_vars: vec![ + self.primitives.ndarray_dtype_tvar.ty, + self.primitives.ndarray_ndims_tvar.ty, + ], fields: Vec::default(), methods: vec![ Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0), @@ -751,7 +743,10 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), simple_name: prim.simple_name().into(), signature: self.ndarray_copy_ty.0, - var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], + var_id: vec![ + self.primitives.ndarray_dtype_tvar.id, + self.primitives.ndarray_ndims_tvar.id, + ], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -768,7 +763,10 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), simple_name: prim.simple_name().into(), signature: self.ndarray_fill_ty.0, - var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], + var_id: vec![ + self.primitives.ndarray_dtype_tvar.id, + self.primitives.ndarray_ndims_tvar.id, + ], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 2fcc24c..7051c88 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,8 +1,8 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; -use crate::toplevel::numpy::unpack_ndarray_var_tys; -use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap}; +use crate::toplevel::numpy::unpack_ndarray_params; +use crate::typecheck::typedef::{into_var_map, Mapping, TypeVar, TypeVarId, VarMap}; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -286,6 +286,18 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef ] } +pub fn make_option_type_tvar(unifier: &mut Unifier) -> TypeVar { + unifier.get_fresh_var(Some("option_type_var".into()), None) +} + +pub fn make_ndarray_dtype_tvar(unifier: &mut Unifier) -> TypeVar { + unifier.get_fresh_var(Some("ndarray_dtype".into()), None) +} + +pub fn make_ndarray_ndims_tvar(unifier: &mut Unifier, size_ty: Type) -> TypeVar { + unifier.get_fresh_const_generic_var(size_ty, Some("ndarray_ndims".into()), None) +} + impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { @@ -381,16 +393,16 @@ impl TopLevelComposer { params: VarMap::new(), }); - let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None); + let option_type_tvar = make_option_type_tvar(&mut unifier); let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], ret: bool, - vars: into_var_map([option_type_var]), + vars: into_var_map([option_type_tvar]), })); let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], - ret: option_type_var.ty, - vars: into_var_map([option_type_var]), + ret: option_type_tvar.ty, + vars: into_var_map([option_type_tvar]), })); let option = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Option.id(), @@ -401,7 +413,7 @@ impl TopLevelComposer { ] .into_iter() .collect::>(), - params: into_var_map([option_type_var]), + params: into_var_map([option_type_tvar]), }); let size_t_ty = match size_t { @@ -410,9 +422,8 @@ impl TopLevelComposer { _ => unreachable!(), }; - let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); - let ndarray_ndims_tvar = - unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); + let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier); + let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, size_t_ty); let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None); let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], @@ -451,7 +462,10 @@ impl TopLevelComposer { str, exception, option, + option_type_tvar, ndarray, + ndarray_dtype_tvar, + ndarray_ndims_tvar, size_t, }; unifier.put_primitive_store(&primitives); @@ -881,22 +895,26 @@ pub fn parse_parameter_default_value( } /// Obtains the element type of an array-like type. -pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { +pub fn arraylike_flatten_element_type( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, +) -> Type { match &*unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(unifier, ty).0 + unpack_ndarray_params(unifier, store, ty).dtype } - TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty), + TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, store, *ty), _ => ty, } } /// Obtains the number of dimensions of an array-like type. -pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { +pub fn arraylike_get_ndims(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) -> u64 { match &*unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let ndims = unpack_ndarray_var_tys(unifier, ty).1; + let ndims = unpack_ndarray_params(unifier, store, ty).ndims; let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else { panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims)) }; @@ -908,7 +926,7 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { u64::try_from(values[0].clone()).unwrap() } - TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1, + TypeEnum::TList { ty } => arraylike_get_ndims(unifier, store, *ty) + 1, _ => 0, } } diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 63f6173..c3036cd 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -2,7 +2,7 @@ use crate::{ toplevel::helper::PrimDef, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap}, + typedef::{Type, TypeEnum, Unifier, VarMap}, }, }; use itertools::Itertools; @@ -57,29 +57,25 @@ pub fn subst_ndarray_tvars( unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) } -fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> { +#[derive(Clone, Copy, Debug)] +pub struct NDArrayParams { + pub dtype: Type, + pub ndims: Type, +} + +/// Extract the [`Type`]s of `ndarray`. +#[must_use] +pub fn unpack_ndarray_params( + unifier: &Unifier, + store: &PrimitiveStore, + ndarray: Type, +) -> NDArrayParams { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; debug_assert_eq!(*obj_id, PrimDef::NDArray.id()); debug_assert_eq!(params.len(), 2); - - params - .iter() - .sorted_by_key(|(obj_id, _)| *obj_id) - .map(|(var_id, ty)| (*var_id, *ty)) - .collect_vec() -} - -/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds -/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` -/// respectively. -pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) { - unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap() -} - -/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to -/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively. -pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) { - unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap() + let dtype = *params.get(&store.ndarray_dtype_tvar.id).unwrap(); + let ndims = *params.get(&store.ndarray_ndims_tvar.id).unwrap(); + NDArrayParams { dtype, ndims } } diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index f2b995e..3550361 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,6 +1,6 @@ use crate::symbol_resolver::SymbolValue; use crate::toplevel::helper::PrimDef; -use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; +use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_params}; use crate::typecheck::{ type_inferencer::*, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, @@ -369,16 +369,16 @@ pub fn typeof_ndarray_broadcast( if is_left_ndarray && is_right_ndarray { // Perform broadcasting on two ndarray operands. - let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left); - let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right); + let left_params = unpack_ndarray_params(unifier, primitives, left); + let right_params = unpack_ndarray_params(unifier, primitives, right); - assert!(unifier.unioned(left_ty_dtype, right_ty_dtype)); + assert!(unifier.unioned(left_params.dtype, right_params.dtype)); - let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) { + let left_ty_ndims = match &*unifier.get_ty_immutable(left_params.ndims) { TypeEnum::TLiteral { values, .. } => values.clone(), _ => unreachable!(), }; - let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) { + let right_ty_ndims = match &*unifier.get_ty_immutable(right_params.ndims) { TypeEnum::TLiteral { values, .. } => values.clone(), _ => unreachable!(), }; @@ -397,11 +397,11 @@ pub fn typeof_ndarray_broadcast( .collect_vec(); let res_ndims = unifier.get_fresh_literal(res_ndims, None); - Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims))) + Ok(make_ndarray_ty(unifier, primitives, Some(left_params.dtype), Some(res_ndims))) } else { let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) }; - let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty); + let ndarray_ty_dtype = unpack_ndarray_params(unifier, primitives, ndarray_ty).dtype; if unifier.unioned(ndarray_ty_dtype, scalar_ty) { Ok(ndarray_ty) @@ -444,7 +444,7 @@ pub fn typeof_binop( } Operator::MatMult => { - let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = unpack_ndarray_params(unifier, primitives, lhs).ndims; let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { TypeEnum::TLiteral { values, .. } => { assert_eq!(values.len(), 1); @@ -452,7 +452,7 @@ pub fn typeof_binop( } _ => unreachable!(), }; - let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = unpack_ndarray_params(unifier, primitives, rhs).ndims; let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { TypeEnum::TLiteral { values, .. } => { assert_eq!(values.len(), 1); @@ -552,7 +552,7 @@ pub fn typeof_unaryop( Unaryop::UAdd | Unaryop::USub => { if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) { - let (dtype, _) = unpack_ndarray_var_tys(unifier, operand); + let dtype = unpack_ndarray_params(unifier, primitives, operand).dtype; if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) { return Err(if op == Unaryop::UAdd { "The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string() @@ -586,7 +586,7 @@ pub fn typeof_cmpop( Ok(Some(if is_left_ndarray || is_right_ndarray { let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?; - let (_, ndims) = unpack_ndarray_var_tys(unifier, brd); + let ndims = unpack_ndarray_params(unifier, primitives, brd).ndims; make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims)) } else if unifier.unioned(lhs, rhs) { @@ -653,8 +653,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty)); - let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); - let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); + let ndarray_dtype_t = unpack_ndarray_params(unifier, store, ndarray_t).dtype; + let ndarray_unsized_dtype_t = unpack_ndarray_params(unifier, store, ndarray_unsized_t).dtype; impl_basic_arithmetic( unifier, store, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 251b91e..e0e03dd 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -4,13 +4,15 @@ use std::iter::once; use std::ops::Not; use std::{cell::RefCell, sync::Arc}; -use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; +use super::typedef::{ + Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, TypeVar, Unifier, VarMap, +}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef}, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, + numpy::{make_ndarray_ty, unpack_ndarray_params}, TopLevelContext, }, }; @@ -49,7 +51,11 @@ pub struct PrimitiveStore { pub str: Type, pub exception: Type, pub option: Type, + /// The contained type of an `Option` + pub option_type_tvar: TypeVar, pub ndarray: Type, + pub ndarray_dtype_tvar: TypeVar, + pub ndarray_ndims_tvar: TypeVar, pub size_t: u32, } @@ -896,7 +902,8 @@ impl<'a> Inferencer<'a> { let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); + let ndarray_ndims = + unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).ndims; make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) } else { @@ -934,9 +941,7 @@ impl<'a> Inferencer<'a> { let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); - - ndarray_dtype + unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).dtype } else { arg0_ty }; @@ -988,14 +993,14 @@ impl<'a> Inferencer<'a> { let arg0_dtype = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - unpack_ndarray_var_tys(self.unifier, arg0_ty).0 + unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).dtype } else { arg0_ty }; let arg1_dtype = if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - unpack_ndarray_var_tys(self.unifier, arg1_ty).0 + unpack_ndarray_params(self.unifier, self.primitives, arg1_ty).dtype } else { arg1_ty }; @@ -1026,7 +1031,8 @@ impl<'a> Inferencer<'a> { // (float, int32), so convert it to align with the dtype of the first arg let arg1_ty = if id == &"np_ldexp".into() { if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty); + let ndims = + unpack_ndarray_params(self.unifier, self.primitives, arg1_ty).ndims; make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims)) } else { @@ -1115,7 +1121,8 @@ impl<'a> Inferencer<'a> { let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); + let ndarray_ndims = + unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).ndims; make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) } else { @@ -1258,7 +1265,8 @@ impl<'a> Inferencer<'a> { let ndmin_kw = keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into())); - let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap()); + let ty = + arraylike_flatten_element_type(self.unifier, self.primitives, arg0.custom.unwrap()); let ndims = if let Some(ndmin_kw) = ndmin_kw { match &ndmin_kw.node.value.node { ExprKind::Constant { value, .. } => match value { @@ -1266,10 +1274,10 @@ impl<'a> Inferencer<'a> { _ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])), }, - _ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap()), + _ => arraylike_get_ndims(self.unifier, self.primitives, arg0.custom.unwrap()), } } else { - arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) + arraylike_get_ndims(self.unifier, self.primitives, arg0.custom.unwrap()) }; let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); @@ -1666,8 +1674,12 @@ impl<'a> Inferencer<'a> { let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (_, ndims) = - unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let ndims = unpack_ndarray_params( + self.unifier, + self.primitives, + value.custom.unwrap(), + ) + .ndims; make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) } @@ -1680,8 +1692,13 @@ impl<'a> Inferencer<'a> { ExprKind::Constant { value: ast::Constant::Int(val), .. } => { match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (_, ndims) = - unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let ndims = unpack_ndarray_params( + self.unifier, + self.primitives, + value.custom.unwrap(), + ) + .ndims; + self.infer_subscript_ndarray(value, ty, ndims) } _ => { @@ -1724,7 +1741,9 @@ impl<'a> Inferencer<'a> { } } - let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let ndims = + unpack_ndarray_params(self.unifier, self.primitives, value.custom.unwrap()) + .ndims; let ndarray_ty = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?; @@ -1751,8 +1770,12 @@ impl<'a> Inferencer<'a> { Ok(ty) } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (_, ndims) = - unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let ndims = unpack_ndarray_params( + self.unifier, + self.primitives, + value.custom.unwrap(), + ) + .ndims; let valid_index_tys = [self.primitives.int32, self.primitives.isize()] .into_iter() diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index b5e2b1e..44ee01b 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -1,5 +1,8 @@ use super::super::{magic_methods::with_fields, typedef::*}; use super::*; +use crate::toplevel::helper::{ + make_ndarray_dtype_tvar, make_ndarray_ndims_tvar, make_option_type_tvar, +}; use crate::{ codegen::CodeGenContext, symbol_resolver::ValueEnum, @@ -132,14 +135,14 @@ impl TestEnvironment { fields: HashMap::new(), params: VarMap::new(), }); + let option_type_tvar = make_option_type_tvar(&mut unifier); let option = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Option.id(), fields: HashMap::new(), params: VarMap::new(), }); - let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); - let ndarray_ndims_tvar = - unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None); + let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier); + let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, uint64); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::NDArray.id(), fields: HashMap::new(), @@ -157,7 +160,10 @@ impl TestEnvironment { uint32, uint64, option, + option_type_tvar, ndarray, + ndarray_dtype_tvar, + ndarray_ndims_tvar, size_t: 64, }; unifier.put_primitive_store(&primitives); @@ -268,16 +274,22 @@ impl TestEnvironment { fields: HashMap::new(), params: VarMap::new(), }); + + let option_type_tvar = make_option_type_tvar(&mut unifier); let option = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Option.id(), fields: HashMap::new(), params: VarMap::new(), }); + + let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier); + let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, uint64); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::NDArray.id(), fields: HashMap::new(), params: VarMap::new(), }); + identifier_mapping.insert("None".into(), none); for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"] .iter() @@ -312,7 +324,10 @@ impl TestEnvironment { uint32, uint64, option, + option_type_tvar, ndarray, + ndarray_dtype_tvar, + ndarray_ndims_tvar, size_t: 64, };