From 6892a4848e8b7605c0fe523036b1370b74d21c46 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 27 Jun 2024 16:33:27 +0800 Subject: [PATCH] core/typedef: WIP - Add OptionType and NDArrayType --- nac3artiq/src/codegen.rs | 10 +- nac3artiq/src/symbol_resolver.rs | 53 +++-- nac3core/src/codegen/builtin_fns.rs | 202 ++++++++++++---- nac3core/src/codegen/expr.rs | 82 +++---- nac3core/src/codegen/mod.rs | 14 +- nac3core/src/codegen/numpy.rs | 25 +- nac3core/src/codegen/stmt.rs | 8 +- nac3core/src/symbol_resolver.rs | 9 +- nac3core/src/toplevel/builtins.rs | 218 +++++++++++------- nac3core/src/toplevel/helper.rs | 17 +- nac3core/src/toplevel/mod.rs | 2 +- nac3core/src/toplevel/numpy.rs | 85 ------- nac3core/src/toplevel/primitive_type.rs | 98 ++++++++ nac3core/src/toplevel/type_annotation.rs | 9 +- nac3core/src/typecheck/magic_methods.rs | 156 +++++++++---- nac3core/src/typecheck/type_inferencer/mod.rs | 129 +++++++---- .../src/typecheck/type_inferencer/test.rs | 4 + nac3core/src/typecheck/typedef/mod.rs | 3 +- 18 files changed, 717 insertions(+), 407 deletions(-) delete mode 100644 nac3core/src/toplevel/numpy.rs create mode 100644 nac3core/src/toplevel/primitive_type.rs diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 994c5048..0815e888 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -7,7 +7,7 @@ use nac3core::{ }, symbol_resolver::ValueEnum, toplevel::{helper::PrimDef, DefinitionId, GenCall}, - typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap}, + typecheck::typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, VarMap}, }; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; @@ -23,7 +23,7 @@ use pyo3::{ use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; -use nac3core::toplevel::numpy::unpack_ndarray_var_tys; +use nac3core::toplevel::primitive_type; use std::{ collections::hash_map::DefaultHasher, collections::HashMap, @@ -399,7 +399,9 @@ fn gen_rpc_tag( gen_rpc_tag(ctx, *ty, buffer)?; } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + let ndarray_ty = primitive_type::NDArrayType::create(ty, &mut ctx.unifier); + let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty; + let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty; let ndarray_ndims = if let TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims) { @@ -645,7 +647,7 @@ pub fn attributes_writeback( let ty = ty.unwrap(); match &*ctx.unifier.get_ty(ty) { TypeEnum::TObj { fields, obj_id, .. } - if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => + if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier) => { // we only care about primitive attributes // for non-primitive attributes, they should be in another global diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 5289e6d1..ec518c7e 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -11,11 +11,7 @@ use nac3core::{ CodeGenContext, CodeGenerator, }, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, - toplevel::{ - helper::PrimDef, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, - DefinitionId, TopLevelDef, - }, + toplevel::{helper::PrimDef, primitive_type, DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, TypeEnum, TypeVar, Unifier, VarMap}, @@ -337,13 +333,18 @@ impl InnerResolver { // do not handle type var param and concrete check here let var = unifier.get_dummy_var().ty; let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty; - let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims)); - Ok(Ok((ndarray, false))) + let ndarray = primitive_type::NDArrayType::from_primitive( + unifier, + primitives, + Some(var), + Some(ndims), + ); + Ok(Ok((ndarray.into(), false))) } else if ty_id == self.primitive_ids.tuple { // do not handle type var param and concrete check here Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) } else if ty_id == self.primitive_ids.option { - Ok(Ok((primitives.option, false))) + Ok(Ok((primitives.option.into(), false))) } else if ty_id == self.primitive_ids.none { unreachable!("none cannot be typeid") } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() { @@ -510,7 +511,16 @@ impl InnerResolver { )); } - Ok(Ok((make_ndarray_ty(unifier, primitives, Some(ty.0), None), true))) + Ok(Ok(( + primitive_type::NDArrayType::from_primitive( + unifier, + primitives, + Some(ty.0), + None, + ) + .into(), + true, + ))) } TypeEnum::TTuple { .. } => { let args = match args @@ -719,7 +729,9 @@ impl InnerResolver { } } (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => { - let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); + let ndarray = primitive_type::NDArrayType::create(extracted_ty, unifier); + let ty = ndarray.dtype_tvar(unifier).ty; + let ndims = ndarray.ndims_tvar(unifier).ty; let len: usize = obj.getattr("ndim")?.extract()?; if len == 0 { assert!(matches!( @@ -734,10 +746,14 @@ impl InnerResolver { match dtype_ty { Ok((t, _)) => match unifier.unify(ty, t) { Ok(()) => { - let ndarray_ty = - make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims)); + let ndarray_ty = primitive_type::NDArrayType::from_primitive( + unifier, + primitives, + Some(ty), + Some(ndims), + ); - Ok(Ok(ndarray_ty)) + Ok(Ok(ndarray_ty.into())) } Err(e) => Ok(Err(format!( "type error ({}) for the ndarray", @@ -760,7 +776,7 @@ impl InnerResolver { // special handling for option type since its class member layout in python side // is special and cannot be mapped directly to a nac3 type as below (TypeEnum::TObj { obj_id, params, .. }, false) - if *obj_id == primitives.option.obj_id(unifier).unwrap() => + if *obj_id == primitives.option.obj_id(unifier) => { let Ok(field_data) = obj.getattr("_nac3_option") else { unreachable!("cannot be None") @@ -785,7 +801,7 @@ impl InnerResolver { .map(TypeVar::into) .collect::() }); - return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())); + return Ok(Ok(unifier.subst(primitives.option.into(), &var_map).unwrap())); } let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { @@ -1038,8 +1054,9 @@ impl InnerResolver { } else { unreachable!("must be ndarray") }; - let (ndarray_dtype, ndarray_ndims) = - unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); + let ndarray_ty = primitive_type::NDArrayType::create(ndarray_ty, &mut ctx.unifier); + let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty; + let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty; let llvm_usize = generator.get_size_type(ctx.ctx); let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); @@ -1186,7 +1203,7 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.option { let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() { TypeEnum::TObj { obj_id, params, .. } - if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => + if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier) => { *params.iter().next().unwrap().1 } diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index feec51ea..814df426 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -8,8 +8,8 @@ use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::toplevel::helper::PrimDef; -use crate::toplevel::numpy::unpack_ndarray_var_tys; -use crate::typecheck::typedef::Type; +use crate::toplevel::primitive_type; +use crate::typecheck::typedef::{GenericObjectType, Type}; /// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// @@ -66,7 +66,9 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -128,7 +130,9 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -206,7 +210,9 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -273,7 +279,9 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -328,7 +336,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -374,7 +384,9 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -414,7 +426,9 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -475,7 +489,9 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -529,7 +545,9 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -579,7 +597,9 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray = ndarray_elementwise_unaryop_impl( generator, @@ -660,7 +680,9 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); @@ -751,16 +773,24 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; + let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else { unreachable!() }; @@ -850,7 +880,9 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); @@ -941,16 +973,24 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; + let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else { unreachable!() }; @@ -1008,7 +1048,9 @@ where if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let llvm_usize = generator.get_size_type(ctx.ctx); - let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); + let arg_elem_ty = primitive_type::NDArrayType::create(arg_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1370,16 +1412,24 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; + let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else { unreachable!() }; @@ -1437,16 +1487,24 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; + let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else { unreachable!() }; @@ -1504,16 +1562,24 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; + let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else { unreachable!() }; @@ -1571,16 +1637,24 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; + let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else { unreachable!() }; @@ -1637,12 +1711,22 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( let is_ndarray2 = x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let dtype = - if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; + let dtype = if is_ndarray1 { + primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty + } else { + x1_ty + }; let x1_scalar_ty = dtype; - let x2_scalar_ty = - if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; + let x2_scalar_ty = if is_ndarray2 { + primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty + } else { + x2_ty + }; numpy::ndarray_elementwise_binop_impl( generator, @@ -1694,16 +1778,24 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; + let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else { unreachable!() }; @@ -1761,16 +1853,24 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; + let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); ndarray_dtype1 } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 + primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 + primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } else { unreachable!() }; diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 2d188058..5ab6f269 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1,5 +1,9 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; +use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator}; +use crate::toplevel::primitive_type; +use crate::toplevel::primitive_type::OptionType; +use crate::typecheck::typedef::GenericObjectType; use crate::{ codegen::{ classes::{ @@ -15,11 +19,7 @@ use crate::{ CodeGenContext, CodeGenTask, }, symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{ - helper::PrimDef, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, - DefinitionId, TopLevelDef, - }, + toplevel::{helper::PrimDef, DefinitionId, TopLevelDef}, typecheck::{ magic_methods::{binop_assign_name, binop_name, unaryop_name}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -36,8 +36,6 @@ use nac3parser::ast::{ self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, }; -use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator}; - pub fn get_subst_key( unifier: &mut Unifier, obj: Option, @@ -162,14 +160,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { self.builder.build_load(ptr, "tup_val").unwrap() } SymbolValue::OptionSome(v) => { - let ty = match self.unifier.get_ty_immutable(ty).as_ref() { - TypeEnum::TObj { obj_id, params, .. } - if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() => - { - *params.iter().next().unwrap().1 - } - _ => unreachable!("must be option type"), - }; + let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty; let val = self.gen_symbol_val(generator, v, ty); let ptr = generator .gen_var_alloc(self, val.get_type(), Some("default_opt_some")) @@ -178,14 +169,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ptr.into() } SymbolValue::OptionNone => { - let ty = match self.unifier.get_ty_immutable(ty).as_ref() { - TypeEnum::TObj { obj_id, params, .. } - if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() => - { - *params.iter().next().unwrap().1 - } - _ => unreachable!("must be option type"), - }; + let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty; let actual_ptr_type = self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default()); actual_ptr_type.const_null().into() @@ -1206,8 +1190,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); + let ndarray_dtype1 = primitive_type::NDArrayType::create(ty1, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; + let ndarray_dtype2 = primitive_type::NDArrayType::create(ty2, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); @@ -1256,8 +1244,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Ok(Some(res.as_base_value().into())) } else { - let (ndarray_dtype, _) = - unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); + let ndarray_dtype = primitive_type::NDArrayType::create( + if is_ndarray1 { ty1 } else { ty2 }, + &mut ctx.unifier, + ) + .dtype_tvar(&mut ctx.unifier) + .ty; let ndarray_val = NDArrayValue::from_ptr_val( if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), llvm_usize, @@ -1443,7 +1435,9 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( } } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); - let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + let ndarray_dtype = primitive_type::NDArrayType::create(ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); @@ -1527,8 +1521,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); return if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); + let ndarray_dtype1 = primitive_type::NDArrayType::create(left_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; + let ndarray_dtype2 = + primitive_type::NDArrayType::create(right_ty, &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty; assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); @@ -1562,10 +1561,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( Ok(Some(res.as_base_value().into())) } else { - let (ndarray_dtype, _) = unpack_ndarray_var_tys( - &mut ctx.unifier, + let ndarray_dtype = primitive_type::NDArrayType::create( if is_ndarray1 { left_ty } else { right_ty }, - ); + &mut ctx.unifier, + ) + .dtype_tvar(&mut ctx.unifier) + .ty; let res = numpy::ndarray_elementwise_binop_impl( generator, ctx, @@ -1788,9 +1789,13 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(), None, ); - let ndarray_ty = - make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty)); - let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); + let ndarray_ty = primitive_type::NDArrayType::from_primitive( + &mut ctx.unifier, + &ctx.primitives, + Some(ty), + Some(ndarray_ndims_ty), + ); + let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty.into()).into_pointer_type(); let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); @@ -2082,7 +2087,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ExprKind::Name { id, .. } if id == &"none".into() => { match ( ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(), - ctx.unifier.get_ty(ctx.primitives.option).as_ref(), + ctx.unifier.get_ty(ctx.primitives.option.into()).as_ref(), ) { (TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. }) if *obj_id == *opt_id => @@ -2464,8 +2469,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; // directly generate code for option.unwrap // since it needs to return static value to optimize for kernel invariant - if attr == &"unwrap".into() - && id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() + if attr == &"unwrap".into() && id == ctx.primitives.option.obj_id(&ctx.unifier) { match val { ValueEnum::Static(v) => { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index eb3b9d95..b26dae9d 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,7 +1,7 @@ use crate::{ codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, symbol_resolver::{StaticValue, SymbolResolver}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, + toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, @@ -47,6 +47,9 @@ pub mod stmt; #[cfg(test)] mod test; +use crate::toplevel::primitive_type; +use crate::toplevel::primitive_type::OptionType; +use crate::typecheck::typedef::GenericObjectType; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; @@ -457,7 +460,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); + let dtype = primitive_type::NDArrayType::create(ty, unifier) + .dtype_tvar(unifier) + .ty; let element_type = get_llvm_type( ctx, module, generator, unifier, top_level, type_cache, dtype, ); @@ -634,7 +639,10 @@ pub fn gen_func_impl< range: unifier.get_representative(primitives.range), str: unifier.get_representative(primitives.str), exception: unifier.get_representative(primitives.exception), - option: unifier.get_representative(primitives.option), + option: OptionType::create( + unifier.get_representative(primitives.option.into()), + &mut unifier, + ), ..primitives }; diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 3fab259e..c38ff3a2 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -17,12 +17,8 @@ use crate::{ CodeGenContext, CodeGenerator, }, symbol_resolver::ValueEnum, - toplevel::{ - helper::PrimDef, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, - DefinitionId, - }, - typecheck::typedef::{FunSignature, Type, TypeEnum}, + toplevel::{helper::PrimDef, primitive_type, DefinitionId}, + typecheck::typedef::{FunSignature, GenericObjectType, Type, TypeEnum}, }; use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; use inkwell::{ @@ -38,12 +34,17 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, ) -> Result, String> { - let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); + let ndarray_ty = primitive_type::NDArrayType::from_primitive( + &mut ctx.unifier, + &ctx.primitives, + Some(elem_ty), + None, + ); let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_ndarray_t = ctx - .get_llvm_type(generator, ndarray_ty) + .get_llvm_type(generator, ndarray_ty.into()) .into_pointer_type() .get_element_type() .into_struct_type(); @@ -1799,7 +1800,9 @@ pub fn gen_ndarray_array<'ctx>( let obj_ty = fun.0.args[0].ty; let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 + primitive_type::NDArrayType::create(obj_ty, &mut context.unifier) + .dtype_tvar(&mut context.unifier) + .ty } TypeEnum::TList { ty } => { @@ -1939,7 +1942,9 @@ pub fn gen_ndarray_copy<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); + let this_elem_ty = primitive_type::NDArrayType::create(this_ty, &mut context.unifier) + .dtype_tvar(&mut context.unifier) + .ty; let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 5bae9a94..649bcf55 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -4,13 +4,15 @@ use super::{ irrt::{handle_slice_indices, list_slice_assignment}, CodeGenContext, CodeGenerator, }; +use crate::toplevel::primitive_type; +use crate::typecheck::typedef::GenericObjectType; use crate::{ codegen::{ classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue}, expr::gen_binop_expr, gen_in_range_check, }, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, + toplevel::{helper::PrimDef, DefinitionId, TopLevelDef}, typecheck::typedef::{FunSignature, Type, TypeEnum}, }; use inkwell::{ @@ -245,7 +247,9 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { TypeEnum::TList { ty } => *ty, TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 + primitive_type::NDArrayType::create(target.custom.unwrap(), &mut ctx.unifier) + .dtype_tvar(&mut ctx.unifier) + .ty } _ => unreachable!(), }; diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 753cc076..b4448776 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -3,6 +3,7 @@ use std::rc::Rc; use std::sync::Arc; use std::{collections::HashMap, collections::HashSet, fmt::Display}; +use crate::typecheck::typedef::GenericObjectType; use crate::{ codegen::{CodeGenContext, CodeGenerator}, toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef}, @@ -43,7 +44,7 @@ impl SymbolValue { ) -> Result { match constant { Constant::None => { - if unifier.unioned(expected_ty, primitives.option) { + if unifier.unioned(expected_ty, primitives.option.into()) { Ok(SymbolValue::OptionNone) } else { Err(format!("Expected {expected_ty:?}, but got Option")) @@ -157,7 +158,7 @@ impl SymbolValue { let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::>(); unifier.add_ty(TypeEnum::TTuple { ty: vs_tys }) } - SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, + SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option.into(), } } @@ -183,13 +184,13 @@ impl SymbolValue { TypeAnnotation::Tuple(vs_tys) } SymbolValue::OptionNone => TypeAnnotation::CustomClass { - id: primitives.option.obj_id(unifier).unwrap(), + id: primitives.option.obj_id(unifier), params: Vec::default(), }, SymbolValue::OptionSome(v) => { let ty = v.get_type_annotation(primitives, unifier); TypeAnnotation::CustomClass { - id: primitives.option.obj_id(unifier).unwrap(), + id: primitives.option.obj_id(unifier), params: vec![ty], } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 121038d6..2e99a4e1 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -11,7 +11,6 @@ use inkwell::{ use itertools::Either; use strum::IntoEnumIterator; -use crate::typecheck::typedef::{GenericObjectType, GenericTypeAdapter}; use crate::{ codegen::{ builtin_fns, @@ -25,7 +24,7 @@ use crate::{ stmt::exn_constructor, }, symbol_resolver::SymbolValue, - toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, + toplevel::helper::PrimDef, typecheck::typedef::{into_var_map, TypeVar, VarMap}, }; @@ -304,10 +303,7 @@ struct BuiltinBuilder<'a> { is_some_ty: (Type, bool), unwrap_ty: (Type, bool), - option_tvar: TypeVar, - ndarray_dtype_tvar: TypeVar, - ndarray_ndims_tvar: TypeVar, ndarray_copy_ty: (Type, bool), ndarray_fill_ty: (Type, bool), @@ -316,9 +312,9 @@ struct BuiltinBuilder<'a> { num_ty: TypeVar, num_var_map: VarMap, - ndarray_float: Type, - ndarray_float_2d: Type, - ndarray_num_ty: Type, + ndarray_float: primitive_type::NDArrayType, + ndarray_float_2d: primitive_type::NDArrayType, + ndarray_num_ty: primitive_type::NDArrayType, float_or_ndarray_ty: TypeVar, float_or_ndarray_var_map: VarMap, @@ -345,24 +341,19 @@ impl<'a> BuiltinBuilder<'a> { } = *primitives; // Option-related - let (is_some_ty, unwrap_ty, option_tvar) = - if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option) { - let option = GenericTypeAdapter::create(option, unifier); + let (is_some_ty, unwrap_ty) = + if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option.into()) { ( *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), - option.get_var_at(unifier, 0).unwrap(), ) } else { unreachable!() }; - let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray) else { + let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray.into()) else { unreachable!() }; - let ndarray = GenericTypeAdapter::create(ndarray, unifier); - let ndarray_dtype_tvar = ndarray.get_var_at(unifier, 0).unwrap(); - let ndarray_ndims_tvar = ndarray.get_var_at(unifier, 1).unwrap(); let ndarray_copy_ty = *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); let ndarray_fill_ty = @@ -375,7 +366,8 @@ impl<'a> BuiltinBuilder<'a> { ); let num_var_map = into_var_map([num_ty]); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None); + let ndarray_float = + primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(float), None); let ndarray_float_2d = { let value = match primitives.size_t { 64 => SymbolValue::U64(2u64), @@ -384,16 +376,28 @@ impl<'a> BuiltinBuilder<'a> { }; let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None }); - make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) + primitive_type::NDArrayType::from_primitive( + unifier, + primitives, + Some(float), + Some(ndims), + ) }; - let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.ty), None); - let float_or_ndarray_ty = - unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let ndarray_num_ty = + primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(num_ty.ty), None); + let float_or_ndarray_ty = unifier.get_fresh_var_with_range( + &[float, ndarray_float.into()], + Some("T".into()), + None, + ); let float_or_ndarray_var_map = into_var_map([float_or_ndarray_ty]); - let num_or_ndarray_ty = - unifier.get_fresh_var_with_range(&[num_ty.ty, ndarray_num_ty], Some("T".into()), None); + let num_or_ndarray_ty = unifier.get_fresh_var_with_range( + &[num_ty.ty, ndarray_num_ty.into()], + Some("T".into()), + None, + ); let num_or_ndarray_var_map = into_var_map([num_ty, num_or_ndarray_ty]); let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); @@ -406,10 +410,7 @@ impl<'a> BuiltinBuilder<'a> { is_some_ty, unwrap_ty, - option_tvar, - ndarray_dtype_tvar, - ndarray_ndims_tvar, ndarray_copy_ty, ndarray_fill_ty, @@ -633,7 +634,7 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::Option => TopLevelDef::Class { name: prim.name().into(), object_id: prim.id(), - type_vars: vec![self.option_tvar.ty], + type_vars: vec![self.primitives.option.type_tvar(self.unifier).ty], fields: Vec::default(), attributes: Vec::default(), methods: vec![ @@ -654,7 +655,7 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), simple_name: prim.simple_name().into(), signature: self.unwrap_ty.0, - var_id: vec![self.option_tvar.id], + var_id: vec![self.primitives.option.type_tvar(self.unifier).id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -668,7 +669,7 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().to_string(), simple_name: prim.simple_name().into(), signature: self.is_some_ty.0, - var_id: vec![self.option_tvar.id], + var_id: vec![self.primitives.option.type_tvar(self.unifier).id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -699,36 +700,40 @@ impl<'a> BuiltinBuilder<'a> { loc: None, }, - PrimDef::FunSome => TopLevelDef::Function { - name: prim.name().into(), - simple_name: prim.simple_name().into(), - signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { - name: "n".into(), - ty: self.option_tvar.ty, - default_value: None, - }], - ret: self.primitives.option, - vars: into_var_map([self.option_tvar]), - })), - var_id: vec![self.option_tvar.id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let alloca = generator - .gen_var_alloc(ctx, arg_val.get_type(), Some("alloca_some")) - .unwrap(); - ctx.builder.build_store(alloca, arg_val).unwrap(); - Ok(Some(alloca.into())) - }, - )))), - loc: None, - }, + PrimDef::FunSome => { + let option_tvar = self.primitives.option.type_tvar(self.unifier); + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "n".into(), + ty: option_tvar.ty, + default_value: None, + }], + ret: self.primitives.option.into(), + vars: into_var_map([option_tvar]), + })), + var_id: vec![self.primitives.option.type_tvar(self.unifier).id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + let alloca = generator + .gen_var_alloc(ctx, arg_val.get_type(), Some("alloca_some")) + .unwrap(); + ctx.builder.build_store(alloca, arg_val).unwrap(); + Ok(Some(alloca.into())) + }, + )))), + loc: None, + } + } _ => { unreachable!() @@ -737,7 +742,7 @@ impl<'a> BuiltinBuilder<'a> { } /// Build the class `ndarray` and its associated methods. - fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef { + fn build_ndarray_class_related(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed( prim, &[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill], @@ -747,7 +752,10 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::NDArray => TopLevelDef::Class { name: prim.name().into(), object_id: prim.id(), - type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty], + type_vars: vec![ + self.primitives.ndarray.dtype_tvar(self.unifier).ty, + self.primitives.ndarray.ndims_tvar(self.unifier).ty, + ], fields: Vec::default(), attributes: Vec::default(), methods: vec![ @@ -764,7 +772,10 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), simple_name: prim.simple_name().into(), signature: self.ndarray_copy_ty.0, - var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], + var_id: vec![ + self.primitives.ndarray.dtype_tvar(self.unifier).id, + self.primitives.ndarray.ndims_tvar(self.unifier).id, + ], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -781,7 +792,10 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), simple_name: prim.simple_name().into(), signature: self.ndarray_fill_ty.0, - var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], + var_id: vec![ + self.primitives.ndarray.dtype_tvar(self.unifier).id, + self.primitives.ndarray.ndims_tvar(self.unifier).id, + ], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -870,15 +884,26 @@ impl<'a> BuiltinBuilder<'a> { // The size variant of the function determines the size of the returned int. let int_sized = size_variant.of_int(self.primitives); - let ndarray_int_sized = - make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); - let ndarray_float = - make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); + let ndarray_int_sized = primitive_type::NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(int_sized), + Some(common_ndim.ty), + ); + let ndarray_float = primitive_type::NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(float), + Some(common_ndim.ty), + ); - let p0_ty = - self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let p0_ty = self.unifier.get_fresh_var_with_range( + &[float, ndarray_float.into()], + Some("T".into()), + None, + ); let ret_ty = self.unifier.get_fresh_var_with_range( - &[int_sized, ndarray_int_sized], + &[int_sized, ndarray_int_sized.into()], Some("R".into()), None, ); @@ -930,19 +955,30 @@ impl<'a> BuiltinBuilder<'a> { None, ); - let ndarray_float = - make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); + let ndarray_float = primitive_type::NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(float), + Some(common_ndim.ty), + ); // The size variant of the function determines the type of int returned let int_sized = size_variant.of_int(self.primitives); - let ndarray_int_sized = - make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); + let ndarray_int_sized = primitive_type::NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(int_sized), + Some(common_ndim.ty), + ); - let p0_ty = - self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let p0_ty = self.unifier.get_fresh_var_with_range( + &[float, ndarray_float.into()], + Some("T".into()), + None, + ); let ret_ty = self.unifier.get_fresh_var_with_range( - &[int_sized, ndarray_int_sized], + &[int_sized, ndarray_int_sized.into()], Some("R".into()), None, ); @@ -1005,7 +1041,7 @@ impl<'a> BuiltinBuilder<'a> { self.unifier, &VarMap::new(), prim.name(), - self.ndarray_float, + self.ndarray_float.into(), &[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], Box::new(move |ctx, obj, fun, args, generator| { let func = match prim { @@ -1051,7 +1087,7 @@ impl<'a> BuiltinBuilder<'a> { default_value: Some(SymbolValue::U32(0)), }, ], - ret: ndarray, + ret: ndarray.into(), vars: into_var_map([tv]), })), var_id: vec![tv.id], @@ -1074,7 +1110,7 @@ impl<'a> BuiltinBuilder<'a> { self.unifier, &into_var_map([tv]), prim.name(), - self.primitives.ndarray, + self.primitives.ndarray.into(), // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a // type variable &[(self.list_int32, "shape"), (tv.ty, "fill_value")], @@ -1103,7 +1139,7 @@ impl<'a> BuiltinBuilder<'a> { default_value: Some(SymbolValue::I32(0)), }, ], - ret: self.ndarray_float_2d, + ret: self.ndarray_float_2d.into(), vars: VarMap::default(), })), var_id: Vec::default(), @@ -1123,7 +1159,7 @@ impl<'a> BuiltinBuilder<'a> { self.unifier, &VarMap::new(), prim.name(), - self.ndarray_float_2d, + self.ndarray_float_2d.into(), &[(int32, "n")], Box::new(|ctx, obj, fun, args, generator| { gen_ndarray_identity(ctx, &obj, fun, &args, generator) @@ -1338,10 +1374,15 @@ impl<'a> BuiltinBuilder<'a> { let tvar = self.unifier.get_fresh_var(Some("L".into()), None); let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty }); let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None); - let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.ty), Some(ndims.ty)); + let ndarray = primitive_type::NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(tvar.ty), + Some(ndims.ty), + ); let arg_ty = self.unifier.get_fresh_var_with_range( - &[list, ndarray, self.primitives.range], + &[list, ndarray.into(), self.primitives.range], Some("I".into()), None, ); @@ -1799,8 +1840,13 @@ impl<'a> BuiltinBuilder<'a> { } fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar { - let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None); + let ndarray = primitive_type::NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(scalar_ty), + None, + ); - self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None) + self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray.into()], Some("T".into()), None) } } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 73812505..8b17ee87 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,14 +1,13 @@ use std::convert::TryInto; +use super::*; use crate::symbol_resolver::SymbolValue; -use crate::toplevel::numpy::unpack_ndarray_var_tys; -use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap}; +use crate::toplevel::primitive_type::{NDArrayType, OptionType}; +use crate::typecheck::typedef::{into_var_map, GenericObjectType, Mapping, TypeVarId, VarMap}; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; -use super::*; - /// All primitive types and functions in nac3core. #[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)] pub enum PrimDef { @@ -403,6 +402,7 @@ impl TopLevelComposer { .collect::>(), params: into_var_map([option_type_var]), }); + let option = OptionType::create(option, &mut unifier); let size_t_ty = match size_t { 32 => uint32, @@ -436,8 +436,9 @@ impl TopLevelComposer { ]), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), }); + let ndarray = NDArrayType::create(ndarray, &mut unifier); - unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray).unwrap(); + unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray.into()).unwrap(); let primitives = PrimitiveStore { int32, @@ -747,7 +748,7 @@ impl TopLevelComposer { TypeAnnotation::CustomClass { id: e_id, params: e_param }, ) => { *f_id == *e_id - && *f_id == primitive.option.obj_id(unifier).unwrap() + && *f_id == primitive.option.obj_id(unifier) && (f_param.is_empty() || (f_param.len() == 1 && e_param.len() == 1 @@ -885,7 +886,7 @@ pub fn parse_parameter_default_value( pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { match &*unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(unifier, ty).0 + NDArrayType::create(ty, unifier).dtype_tvar(unifier).ty } TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty), @@ -897,7 +898,7 @@ pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { match &*unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let ndims = unpack_ndarray_var_tys(unifier, ty).1; + let ndims = NDArrayType::create(ty, unifier).ndims_tvar(unifier).ty; let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else { panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims)) }; diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 7dfd8373..71b574ad 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -30,7 +30,7 @@ pub struct DefinitionId(pub usize); pub mod builtins; pub mod composer; pub mod helper; -pub mod numpy; +pub mod primitive_type; pub mod type_annotation; use composer::*; use type_annotation::*; diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs deleted file mode 100644 index 63f6173d..00000000 --- a/nac3core/src/toplevel/numpy.rs +++ /dev/null @@ -1,85 +0,0 @@ -use crate::{ - toplevel::helper::PrimDef, - typecheck::{ - type_inferencer::PrimitiveStore, - typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap}, - }, -}; -use itertools::Itertools; - -/// Creates a `ndarray` [`Type`] with the given type arguments. -/// -/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not -/// specialized. -/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not -/// specialized. -pub fn make_ndarray_ty( - unifier: &mut Unifier, - primitives: &PrimitiveStore, - dtype: Option, - ndims: Option, -) -> Type { - subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims) -} - -/// Substitutes type variables in `ndarray`. -/// -/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not -/// specialized. -/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not -/// specialized. -pub fn subst_ndarray_tvars( - unifier: &mut Unifier, - ndarray: Type, - dtype: Option, - ndims: Option, -) -> Type { - let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { - panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) - }; - debug_assert_eq!(*obj_id, PrimDef::NDArray.id()); - - if dtype.is_none() && ndims.is_none() { - return ndarray; - } - - let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec(); - debug_assert_eq!(tvar_ids.len(), 2); - - let mut tvar_subst = VarMap::new(); - if let Some(dtype) = dtype { - tvar_subst.insert(tvar_ids[0], dtype); - } - if let Some(ndims) = ndims { - tvar_subst.insert(tvar_ids[1], ndims); - } - - unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) -} - -fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> { - let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { - panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) - }; - debug_assert_eq!(*obj_id, PrimDef::NDArray.id()); - debug_assert_eq!(params.len(), 2); - - params - .iter() - .sorted_by_key(|(obj_id, _)| *obj_id) - .map(|(var_id, ty)| (*var_id, *ty)) - .collect_vec() -} - -/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds -/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` -/// respectively. -pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) { - unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap() -} - -/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to -/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively. -pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) { - unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap() -} diff --git a/nac3core/src/toplevel/primitive_type.rs b/nac3core/src/toplevel/primitive_type.rs new file mode 100644 index 00000000..eaf9996c --- /dev/null +++ b/nac3core/src/toplevel/primitive_type.rs @@ -0,0 +1,98 @@ +use crate::toplevel::helper::PrimDef; +use crate::typecheck::type_inferencer::PrimitiveStore; +use crate::typecheck::typedef::{GenericObjectType, Type, TypeVar, Unifier, VarMap}; + +#[derive(Clone, Copy)] +pub struct OptionType(Type); + +impl OptionType { + pub fn from_primitive( + unifier: &mut Unifier, + primitives: &PrimitiveStore, + type_ty: Option, + ) -> Self { + primitives.option.subst(unifier, type_ty) + } + + pub fn type_tvar(&self, unifier: &mut Unifier) -> TypeVar { + self.get_var_at(unifier, 0).unwrap() + } + + #[must_use] + pub fn subst(&self, unifier: &mut Unifier, type_ty: Option) -> Self { + let new_vars = [(self.type_tvar(unifier).id, type_ty)] + .into_iter() + .filter_map(|(id, ty)| ty.map(|ty| (id, ty))) + .collect::(); + + let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type()); + OptionType(new_ty) + } +} + +impl GenericObjectType for OptionType { + fn try_create(ty: Type, unifier: &mut Unifier) -> Option { + if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::Option.id()) { + Some(OptionType(ty)) + } else { + None + } + } + + fn get_type(&self) -> Type { + self.0 + } +} + +#[derive(Clone, Copy)] +pub struct NDArrayType(Type); + +impl NDArrayType { + pub fn from_primitive( + unifier: &mut Unifier, + primitives: &PrimitiveStore, + dtype: Option, + ndims: Option, + ) -> Self { + primitives.ndarray.subst(unifier, dtype, ndims) + } + + pub fn dtype_tvar(&self, unifier: &mut Unifier) -> TypeVar { + self.get_var_at(unifier, 0).unwrap() + } + + pub fn ndims_tvar(&self, unifier: &mut Unifier) -> TypeVar { + self.get_var_at(unifier, 1).unwrap() + } + + #[must_use] + pub fn subst( + &self, + unifier: &mut Unifier, + dtype_ty: Option, + ndims_ty: Option, + ) -> Self { + let new_vars = + [(self.dtype_tvar(unifier).id, dtype_ty), (self.ndims_tvar(unifier).id, ndims_ty)] + .into_iter() + .filter_map(|(id, ty)| ty.map(|ty| (id, ty))) + .collect::(); + + let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type()); + NDArrayType(new_ty) + } +} + +impl GenericObjectType for NDArrayType { + fn try_create(ty: Type, unifier: &mut Unifier) -> Option { + if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { + Some(NDArrayType(ty)) + } else { + None + } + } + + fn get_type(&self) -> Type { + self.0 + } +} diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index f598badf..ab5833c2 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,7 +1,7 @@ use super::*; use crate::symbol_resolver::SymbolValue; use crate::toplevel::helper::PrimDef; -use crate::typecheck::typedef::VarMap; +use crate::typecheck::typedef::{GenericObjectType, VarMap}; use nac3parser::ast::Constant; #[derive(Clone, Debug)] @@ -267,12 +267,7 @@ pub fn parse_ast_to_type_annotation_kinds( slice.as_ref(), locked, )?; - let id = - if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() { - *obj_id - } else { - unreachable!() - }; + let id = primitives.option.obj_id(unifier); Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] }) } diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index f2b995e2..1b6ded81 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,9 +1,9 @@ use crate::symbol_resolver::SymbolValue; use crate::toplevel::helper::PrimDef; -use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; +use crate::toplevel::primitive_type; use crate::typecheck::{ type_inferencer::*, - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, + typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, Unifier, VarMap}, }; use itertools::Itertools; use nac3parser::ast::StrRef; @@ -369,8 +369,12 @@ pub fn typeof_ndarray_broadcast( if is_left_ndarray && is_right_ndarray { // Perform broadcasting on two ndarray operands. - let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left); - let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right); + let left_ty = primitive_type::NDArrayType::create(left, unifier); + let left_ty_dtype = left_ty.dtype_tvar(unifier).ty; + let left_ty_ndims = left_ty.ndims_tvar(unifier).ty; + let right_ty = primitive_type::NDArrayType::create(right, unifier); + let right_ty_dtype = right_ty.dtype_tvar(unifier).ty; + let right_ty_ndims = right_ty.ndims_tvar(unifier).ty; assert!(unifier.unioned(left_ty_dtype, right_ty_dtype)); @@ -397,11 +401,18 @@ pub fn typeof_ndarray_broadcast( .collect_vec(); let res_ndims = unifier.get_fresh_literal(res_ndims, None); - Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims))) + Ok(primitive_type::NDArrayType::from_primitive( + unifier, + primitives, + Some(left_ty_dtype), + Some(res_ndims), + ) + .into()) } else { let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) }; - let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty); + let ndarray_ty_dtype = + primitive_type::NDArrayType::create(ndarray_ty, unifier).ndims_tvar(unifier).ty; if unifier.unioned(ndarray_ty_dtype, scalar_ty) { Ok(ndarray_ty) @@ -444,7 +455,8 @@ pub fn typeof_binop( } Operator::MatMult => { - let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = + primitive_type::NDArrayType::create(lhs, unifier).ndims_tvar(unifier).ty; let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { TypeEnum::TLiteral { values, .. } => { assert_eq!(values.len(), 1); @@ -452,7 +464,8 @@ pub fn typeof_binop( } _ => unreachable!(), }; - let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = + primitive_type::NDArrayType::create(rhs, unifier).ndims_tvar(unifier).ty; let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { TypeEnum::TLiteral { values, .. } => { assert_eq!(values.len(), 1); @@ -526,7 +539,7 @@ pub fn typeof_unaryop( let operand_obj_id = operand.obj_id(unifier); if op == Unaryop::Not - && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) + && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier)) { return Err( "The truth value of an array with more than one element is ambiguous".to_string() @@ -552,7 +565,8 @@ pub fn typeof_unaryop( Unaryop::UAdd | Unaryop::USub => { if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) { - let (dtype, _) = unpack_ndarray_var_tys(unifier, operand); + let dtype = + primitive_type::NDArrayType::create(operand, unifier).dtype_tvar(unifier).ty; if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) { return Err(if op == Unaryop::UAdd { "The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string() @@ -586,9 +600,15 @@ pub fn typeof_cmpop( Ok(Some(if is_left_ndarray || is_right_ndarray { let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?; - let (_, ndims) = unpack_ndarray_var_tys(unifier, brd); + let ndims = primitive_type::NDArrayType::create(brd, unifier).ndims_tvar(unifier).ty; - make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims)) + primitive_type::NDArrayType::from_primitive( + unifier, + primitives, + Some(primitives.bool), + Some(ndims), + ) + .into() } else if unifier.unioned(lhs, rhs) { primitives.bool } else { @@ -611,64 +631,108 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie /* int ======== */ for t in [int32_t, int64_t, uint32_t, uint64_t] { - let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None); - impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None); - impl_pow(unifier, store, t, &[t, ndarray_int_t], None); + let ndarray_int_t = + primitive_type::NDArrayType::from_primitive(unifier, store, Some(t), None); + impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t.into()], None); + impl_pow(unifier, store, t, &[t, ndarray_int_t.into()], None); impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_shift(unifier, store, t); - impl_div(unifier, store, t, &[t, ndarray_int_t], None); - impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None); - impl_mod(unifier, store, t, &[t, ndarray_int_t], None); + impl_div(unifier, store, t, &[t, ndarray_int_t.into()], None); + impl_floordiv(unifier, store, t, &[t, ndarray_int_t.into()], None); + impl_mod(unifier, store, t, &[t, ndarray_int_t.into()], None); impl_invert(unifier, store, t, Some(t)); impl_not(unifier, store, t, Some(bool_t)); - impl_comparison(unifier, store, t, &[t, ndarray_int_t], None); - impl_eq(unifier, store, t, &[t, ndarray_int_t], None); + impl_comparison(unifier, store, t, &[t, ndarray_int_t.into()], None); + impl_eq(unifier, store, t, &[t, ndarray_int_t.into()], None); } for t in [int32_t, int64_t] { impl_sign(unifier, store, t, Some(t)); } /* float ======== */ - let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None); - let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None); - impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None); - impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None); - impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None); - impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None); - impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None); + let ndarray_float_t = + primitive_type::NDArrayType::from_primitive(unifier, store, Some(float_t), None); + let ndarray_int32_t = + primitive_type::NDArrayType::from_primitive(unifier, store, Some(int32_t), None); + impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None); + impl_pow( + unifier, + store, + float_t, + &[int32_t, float_t, ndarray_int32_t.into(), ndarray_float_t.into()], + None, + ); + impl_div(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None); + impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None); + impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None); impl_sign(unifier, store, float_t, Some(float_t)); impl_not(unifier, store, float_t, Some(bool_t)); - impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None); - impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None); + impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None); + impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None); /* bool ======== */ - let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None); + let ndarray_bool_t = + primitive_type::NDArrayType::from_primitive(unifier, store, Some(bool_t), None); impl_invert(unifier, store, bool_t, Some(int32_t)); impl_not(unifier, store, bool_t, Some(bool_t)); impl_sign(unifier, store, bool_t, Some(int32_t)); - impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); + impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t.into()], None); /* ndarray ===== */ let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); - let ndarray_unsized_t = - make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty)); - let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); - let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); + let ndarray_unsized_t = primitive_type::NDArrayType::from_primitive( + unifier, + store, + None, + Some(ndarray_usized_ndims_tvar.ty), + ); + let ndarray_dtype_t = ndarray_t.dtype_tvar(unifier).ty; + let ndarray_unsized_dtype_t = ndarray_unsized_t.dtype_tvar(unifier).ty; impl_basic_arithmetic( unifier, store, - ndarray_t, - &[ndarray_unsized_t, ndarray_unsized_dtype_t], + ndarray_t.into(), + &[ndarray_unsized_t.into(), ndarray_unsized_dtype_t], + None, + ); + impl_pow( + unifier, + store, + ndarray_t.into(), + &[ndarray_unsized_t.into(), ndarray_unsized_dtype_t], + None, + ); + impl_div(unifier, store, ndarray_t.into(), &[ndarray_t.into(), ndarray_dtype_t], None); + impl_floordiv( + unifier, + store, + ndarray_t.into(), + &[ndarray_unsized_t.into(), ndarray_unsized_dtype_t], + None, + ); + impl_mod( + unifier, + store, + ndarray_t.into(), + &[ndarray_unsized_t.into(), ndarray_unsized_dtype_t], + None, + ); + impl_matmul(unifier, store, ndarray_t.into(), &[ndarray_t.into()], Some(ndarray_t.into())); + impl_sign(unifier, store, ndarray_t.into(), Some(ndarray_t.into())); + impl_invert(unifier, store, ndarray_t.into(), Some(ndarray_t.into())); + impl_eq( + unifier, + store, + ndarray_t.into(), + &[ndarray_unsized_t.into(), ndarray_unsized_dtype_t], + None, + ); + impl_comparison( + unifier, + store, + ndarray_t.into(), + &[ndarray_unsized_t.into(), ndarray_unsized_dtype_t], None, ); - impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); - impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); - impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); - impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); - impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index bb4ab244..784056d0 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -4,14 +4,16 @@ use std::iter::once; use std::ops::Not; use std::{cell::RefCell, sync::Arc}; -use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; +use super::typedef::{ + Call, FunSignature, FuncArg, GenericObjectType, RecordField, Type, TypeEnum, Unifier, VarMap, +}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; +use crate::toplevel::primitive_type::{NDArrayType, OptionType}; use crate::toplevel::TopLevelDef; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef}, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, }, }; @@ -49,8 +51,8 @@ pub struct PrimitiveStore { pub range: Type, pub str: Type, pub exception: Type, - pub option: Type, - pub ndarray: Type, + pub option: OptionType, + pub ndarray: NDArrayType, pub size_t: u32, } @@ -97,8 +99,8 @@ impl IntoIterator for &PrimitiveStore { self.range, self.str, self.exception, - self.option, - self.ndarray, + self.option.into(), + self.ndarray.into(), ] .into_iter() } @@ -528,7 +530,7 @@ impl<'a> Fold<()> for Inferencer<'a> { // the name `none` is special since it may have different types if id == &"none".into() { if let TypeEnum::TObj { params, .. } = - self.unifier.get_ty_immutable(self.primitives.option).as_ref() + &*self.unifier.get_ty_immutable(self.primitives.option.into()) { let var_map = params .iter() @@ -543,7 +545,7 @@ impl<'a> Fold<()> for Inferencer<'a> { (*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty) }) .collect::(); - Some(self.unifier.subst(self.primitives.option, &var_map).unwrap()) + Some(self.unifier.subst(self.primitives.option.into(), &var_map).unwrap()) } else { unreachable!("must be tobj") } @@ -1062,9 +1064,16 @@ impl<'a> Inferencer<'a> { let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); + let ndarray_ndims = + NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty; - make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) + NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(target_ty), + Some(ndarray_ndims), + ) + .into() } else { target_ty }; @@ -1100,9 +1109,7 @@ impl<'a> Inferencer<'a> { let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); - - ndarray_dtype + NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty } else { arg0_ty }; @@ -1154,14 +1161,14 @@ impl<'a> Inferencer<'a> { let arg0_dtype = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - unpack_ndarray_var_tys(self.unifier, arg0_ty).0 + NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty } else { arg0_ty }; let arg1_dtype = if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - unpack_ndarray_var_tys(self.unifier, arg1_ty).0 + NDArrayType::create(arg1_ty, self.unifier).dtype_tvar(self.unifier).ty } else { arg1_ty }; @@ -1192,9 +1199,17 @@ impl<'a> Inferencer<'a> { // (float, int32), so convert it to align with the dtype of the first arg let arg1_ty = if id == &"np_ldexp".into() { if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty); + // let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty); + let ndims = + NDArrayType::create(arg1_ty, self.unifier).ndims_tvar(self.unifier).ty; - make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims)) + NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(target_ty), + Some(ndims), + ) + .into() } else { target_ty } @@ -1281,9 +1296,16 @@ impl<'a> Inferencer<'a> { let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); + let ndarray_ndims = + NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty; - make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) + NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(target_ty), + Some(ndarray_ndims), + ) + .into() } else { target_ty }; @@ -1323,7 +1345,7 @@ impl<'a> Inferencer<'a> { self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape` let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); - let ret = make_ndarray_ty( + let ret = NDArrayType::from_primitive( self.unifier, self.primitives, Some(self.primitives.float), @@ -1335,13 +1357,13 @@ impl<'a> Inferencer<'a> { ty: shape.custom.unwrap(), default_value: None, }], - ret, + ret: ret.into(), vars: VarMap::new(), })); return Ok(Some(Located { location, - custom: Some(ret), + custom: Some(ret.into()), node: ExprKind::Call { func: Box::new(Located { custom: Some(custom), @@ -1374,7 +1396,8 @@ impl<'a> Inferencer<'a> { let ty = arg1.custom.unwrap(); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); - let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); + let ret = + NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims)); let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None }, @@ -1384,13 +1407,13 @@ impl<'a> Inferencer<'a> { default_value: None, }, ], - ret, + ret: ret.into(), vars: VarMap::new(), })); return Ok(Some(Located { location, - custom: Some(ret), + custom: Some(ret.into()), node: ExprKind::Call { func: Box::new(Located { custom: Some(custom), @@ -1428,7 +1451,8 @@ impl<'a> Inferencer<'a> { arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) }; let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); - let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); + let ret = + NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims)); let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ @@ -1448,13 +1472,13 @@ impl<'a> Inferencer<'a> { default_value: Some(SymbolValue::U32(0)), }, ], - ret, + ret: ret.into(), vars: VarMap::new(), })); return Ok(Some(Located { location, - custom: Some(ret), + custom: Some(ret.into()), node: ExprKind::Call { func: Box::new(Located { custom: Some(custom), @@ -1803,9 +1827,13 @@ impl<'a> Inferencer<'a> { TypeEnum::TVar { is_const_generic: false, .. } )); - let constrained_ty = - make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims)); - self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?; + let constrained_ty = NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(dummy_tvar), + Some(ndims), + ); + self.constrain(value.custom.unwrap(), constrained_ty.into(), &value.location)?; let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else { panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims)) @@ -1871,10 +1899,14 @@ impl<'a> Inferencer<'a> { let ndims_ty = self .unifier .get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None); - let subscripted_ty = - make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty)); + let subscripted_ty = NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(dummy_tvar), + Some(ndims_ty), + ); - Ok(subscripted_ty) + Ok(subscripted_ty.into()) } } @@ -1893,10 +1925,17 @@ impl<'a> Inferencer<'a> { let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (_, ndims) = - unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier) + .ndims_tvar(self.unifier) + .ty; - make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) + NDArrayType::from_primitive( + self.unifier, + self.primitives, + Some(ty), + Some(ndims), + ) + .into() } _ => unreachable!(), @@ -1907,8 +1946,10 @@ impl<'a> Inferencer<'a> { ExprKind::Constant { value: ast::Constant::Int(val), .. } => { match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (_, ndims) = - unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier) + .ndims_tvar(self.unifier) + .ty; + self.infer_subscript_ndarray(value, slice, ty, ndims) } _ => { @@ -1951,7 +1992,10 @@ impl<'a> Inferencer<'a> { } } - let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier) + .ndims_tvar(self.unifier) + .ty; + self.infer_subscript_ndarray(value, slice, ty, ndims) } _ => { @@ -1975,8 +2019,9 @@ impl<'a> Inferencer<'a> { Ok(ty) } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (_, ndims) = - unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier) + .ndims_tvar(self.unifier) + .ty; let valid_index_tys = [self.primitives.int32, self.primitives.isize()] .into_iter() diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 9096ae56..7dd384e9 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -139,6 +139,7 @@ impl TestEnvironment { fields: HashMap::new(), params: VarMap::new(), }); + let option = OptionType::create(option, &mut unifier); let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None); @@ -147,6 +148,7 @@ impl TestEnvironment { fields: HashMap::new(), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), }); + let ndarray = NDArrayType::create(ndarray, &mut unifier); let primitives = PrimitiveStore { int32, int64, @@ -273,11 +275,13 @@ impl TestEnvironment { fields: HashMap::new(), params: VarMap::new(), }); + let option = OptionType::create(option, &mut unifier); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::NDArray.id(), fields: HashMap::new(), params: VarMap::new(), }); + let ndarray = NDArrayType::create(ndarray, &mut unifier); identifier_mapping.insert("None".into(), none); for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"] .iter() diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index d268455b..bd20fc85 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -137,7 +137,8 @@ where #[must_use] fn get_type(&self) -> Type; - /// See [`Type::obj_id`]. + /// Similar to [`Type::obj_id`], except that the [`DefinitionId`] is not wrapped within an + /// [`Option`]. #[must_use] fn obj_id(&self, unifier: &Unifier) -> DefinitionId { self.get_type().obj_id(unifier).unwrap()