From fd36f78005b2be23f2ed34852dc12ce3f174ab69 Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 12 Jun 2024 15:01:01 +0800 Subject: [PATCH] core: refactor `PrimitiveDefinitionId` into enum `PrimDef` --- Cargo.lock | 21 ++ nac3artiq/src/codegen.rs | 4 +- nac3artiq/src/symbol_resolver.rs | 6 +- nac3core/Cargo.toml | 2 + nac3core/src/codegen/builtin_fns.rs | 142 ++++++------- nac3core/src/codegen/expr.rs | 43 ++-- nac3core/src/codegen/mod.rs | 10 +- nac3core/src/codegen/numpy.rs | 4 +- nac3core/src/codegen/stmt.rs | 8 +- nac3core/src/toplevel/builtins.rs | 42 ++-- nac3core/src/toplevel/helper.rs | 187 +++++++++++------- nac3core/src/toplevel/numpy.rs | 6 +- nac3core/src/toplevel/type_annotation.rs | 7 +- nac3core/src/typecheck/magic_methods.rs | 33 ++-- nac3core/src/typecheck/type_inferencer/mod.rs | 47 ++--- .../src/typecheck/type_inferencer/test.rs | 50 ++--- 16 files changed, 330 insertions(+), 282 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d92fe93f1..5aa1fde73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -625,6 +625,8 @@ dependencies = [ "parking_lot", "rayon", "regex", + "strum", + "strum_macros", "test-case", ] @@ -1116,6 +1118,25 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.66", +] + [[package]] name = "syn" version = "1.0.109" diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 74e552070..08608efc7 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -6,7 +6,7 @@ use nac3core::{ CodeGenContext, CodeGenerator, }, symbol_resolver::ValueEnum, - toplevel::{helper::PRIMITIVE_DEF_IDS, DefinitionId, GenCall}, + toplevel::{helper::PrimDef, DefinitionId, GenCall}, typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap}, }; @@ -683,7 +683,7 @@ pub fn attributes_writeback( let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); if let Err(e) = - rpc_codegen_callback_fn(ctx, None, (&fun, PRIMITIVE_DEF_IDS.int32), args, generator) + rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator) { return Ok(Err(e)); } diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 9f4118e8e..3c54d3c21 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -3,7 +3,7 @@ use nac3core::{ codegen::{CodeGenContext, CodeGenerator}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, toplevel::{ - helper::PRIMITIVE_DEF_IDS, + helper::PrimDef, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, TopLevelDef, }, @@ -469,7 +469,7 @@ impl InnerResolver { ))); } } - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { if args.len() != 2 { return Ok(Err(format!( "type list needs exactly 2 type parameters, found {}", @@ -660,7 +660,7 @@ impl InnerResolver { } } } - (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => { let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; if len == 0 { diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 6ae73342e..724e0c8cc 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -11,6 +11,8 @@ indexmap = "2.2" parking_lot = "0.12" rayon = "1.8" nac3parser = { path = "../nac3parser" } +strum = "0.26.2" +strum_macros = "0.26.4" [dependencies.inkwell] version = "0.4" diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 0fdd50cc4..0e4b75f46 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -7,7 +7,7 @@ use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; -use crate::toplevel::helper::PRIMITIVE_DEF_IDS; +use crate::toplevel::helper::PrimDef; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::Type; @@ -64,7 +64,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -126,7 +126,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -204,7 +204,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -271,7 +271,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -337,7 +337,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -383,7 +383,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -423,7 +423,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -488,7 +488,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -542,7 +542,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -592,7 +592,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -690,7 +690,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); @@ -783,13 +783,13 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -906,7 +906,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); @@ -999,13 +999,13 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -1086,7 +1086,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); @@ -1126,7 +1126,7 @@ pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1170,7 +1170,7 @@ pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1214,7 +1214,7 @@ pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1254,7 +1254,7 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1294,7 +1294,7 @@ pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1334,7 +1334,7 @@ pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1374,7 +1374,7 @@ pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1414,7 +1414,7 @@ pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1454,7 +1454,7 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1494,7 +1494,7 @@ pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1534,7 +1534,7 @@ pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1574,7 +1574,7 @@ pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1614,7 +1614,7 @@ pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1654,7 +1654,7 @@ pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1694,7 +1694,7 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1734,7 +1734,7 @@ pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1774,7 +1774,7 @@ pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1814,7 +1814,7 @@ pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1854,7 +1854,7 @@ pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1894,7 +1894,7 @@ pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1934,7 +1934,7 @@ pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -1974,7 +1974,7 @@ pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -2014,7 +2014,7 @@ pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -2054,7 +2054,7 @@ pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -2094,7 +2094,7 @@ pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(z) - if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); @@ -2134,7 +2134,7 @@ pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -2174,7 +2174,7 @@ pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(z) - if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); @@ -2214,7 +2214,7 @@ pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -2254,7 +2254,7 @@ pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -2294,7 +2294,7 @@ pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::PointerValue(x) - if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => + if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); @@ -2336,13 +2336,13 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2403,13 +2403,13 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2470,13 +2470,13 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2537,13 +2537,13 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2604,13 +2604,13 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; @@ -2660,13 +2660,13 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2727,13 +2727,13 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) => { let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index b898cc2b8..7a452032c 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -16,7 +16,7 @@ use crate::{ }, symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{ - helper::PRIMITIVE_DEF_IDS, + helper::PrimDef, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, TopLevelDef, }, @@ -1181,15 +1181,13 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Some("f_pow_i"), ); Ok(Some(res.into())) - } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) - || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) + || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); - let is_ndarray1 = - ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_ndarray2 = - ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); @@ -1427,7 +1425,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( .unwrap(), _ => val.into(), } - } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); @@ -1435,16 +1433,15 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function - let op = - if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { - if *op == ast::Unaryop::Invert { - &ast::Unaryop::Not - } else { - unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op)) - } + let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) { + if *op == ast::Unaryop::Invert { + &ast::Unaryop::Not } else { - op - }; + unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op)) + } + } else { + op + }; let res = numpy::ndarray_elementwise_unaryop_impl( generator, @@ -1499,8 +1496,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let left_ty = ctx.unifier.get_representative(left.0.unwrap()); let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap()); - if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) - || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) + || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); @@ -1509,9 +1506,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let op = ops[0].clone(); let is_ndarray1 = - left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = - right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); return if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); @@ -2370,7 +2367,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( )) } Some(v) => Ok(Some(v)), - }; + } } ValueEnum::Dynamic(BasicValueEnum::PointerValue(ptr)) => { let not_null = @@ -2518,7 +2515,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v.data().get(ctx, generator, &index, None).into() } } - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); let v = if let Some(v) = generator.gen_expr(ctx, value)? { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index de8a33467..517c5c9c3 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,9 +1,7 @@ use crate::{ codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, symbol_resolver::{StaticValue, SymbolResolver}, - toplevel::{ - helper::PRIMITIVE_DEF_IDS, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef, - }, + toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, @@ -437,9 +435,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( let result = match &*ty_enum { TObj { obj_id, fields, .. } => { // check to avoid treating non-class primitives as classes - if obj_id.0 <= PRIMITIVE_DEF_IDS.max_id().0 { + if PrimDef::contains_id(*obj_id) { return match &*unifier.get_ty_immutable(ty) { - TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.option => { + TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => { get_llvm_type( ctx, module, @@ -453,7 +451,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( .into() } - TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let element_type = get_llvm_type( ctx, module, generator, unifier, top_level, type_cache, dtype, diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 30eb3b82c..a7e032cec 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -18,7 +18,7 @@ use crate::{ }, symbol_resolver::ValueEnum, toplevel::{ - helper::PRIMITIVE_DEF_IDS, + helper::PrimDef, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, @@ -1775,7 +1775,7 @@ pub fn gen_ndarray_array<'ctx>( let obj_ty = fun.0.args[0].ty; let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 } diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index f14709c87..0448cba1a 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -10,9 +10,7 @@ use crate::{ expr::gen_binop_expr, gen_in_range_check, }, - toplevel::{ - helper::PRIMITIVE_DEF_IDS, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef, - }, + toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, typecheck::typedef::{FunSignature, Type, TypeEnum}, }; use inkwell::{ @@ -188,7 +186,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( v.data().ptr_offset(ctx, generator, &index, name) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { todo!() } @@ -246,7 +244,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( let value = ListValue::from_ptr_val(value, llvm_usize, None); let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { TypeEnum::TList { ty } => *ty, - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 } _ => unreachable!(), diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 2d5e68ceb..fdc157e34 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -22,7 +22,7 @@ use crate::{ stmt::exn_constructor, }, symbol_resolver::SymbolValue, - toplevel::{helper::PRIMITIVE_DEF_IDS, numpy::make_ndarray_ty}, + toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, typecheck::typedef::VarMap, }; @@ -90,7 +90,7 @@ pub fn get_exn_constructor( methods: vec![("__init__".into(), signature, DefinitionId(cons_id))], ancestors: vec![ TypeAnnotation::CustomClass { id: DefinitionId(class_id), params: Vec::default() }, - TypeAnnotation::CustomClass { id: PRIMITIVE_DEF_IDS.exception, params: Vec::default() }, + TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() }, ], constructor: Some(signature), resolver: None, @@ -365,49 +365,49 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PRIMITIVE_DEF_IDS.int32, + PrimDef::Int32.id(), None, "int32".into(), None, None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PRIMITIVE_DEF_IDS.int64, + PrimDef::Int64.id(), None, "int64".into(), None, None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PRIMITIVE_DEF_IDS.float, + PrimDef::Float.id(), None, "float".into(), None, None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PRIMITIVE_DEF_IDS.bool, + PrimDef::Bool.id(), None, "bool".into(), None, None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PRIMITIVE_DEF_IDS.none, + PrimDef::None.id(), None, "none".into(), None, None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PRIMITIVE_DEF_IDS.range, + PrimDef::Range.id(), None, "range".into(), None, None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PRIMITIVE_DEF_IDS.str, + PrimDef::Str.id(), None, "str".into(), None, @@ -415,7 +415,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built ))), Arc::new(RwLock::new(TopLevelDef::Class { name: "Exception".into(), - object_id: PRIMITIVE_DEF_IDS.exception, + object_id: PrimDef::Exception.id(), type_vars: Vec::default(), fields: exception_fields, methods: Vec::default(), @@ -425,14 +425,14 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built loc: None, })), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PRIMITIVE_DEF_IDS.uint32, + PrimDef::UInt32.id(), None, "uint32".into(), None, None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PRIMITIVE_DEF_IDS.uint64, + PrimDef::UInt64.id(), None, "uint64".into(), None, @@ -441,16 +441,16 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built Arc::new(RwLock::new({ TopLevelDef::Class { name: "Option".into(), - object_id: PRIMITIVE_DEF_IDS.option, + object_id: PrimDef::Option.id(), type_vars: vec![option_ty_var], fields: vec![], methods: vec![ - ("is_some".into(), is_some_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.option.0 + 1)), - ("is_none".into(), is_some_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.option.0 + 2)), - ("unwrap".into(), unwrap_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.option.0 + 3)), + ("is_some".into(), is_some_ty.0, PrimDef::OptionIsSome.id()), + ("is_none".into(), is_some_ty.0, PrimDef::OptionIsNone.id()), + ("unwrap".into(), unwrap_ty.0, PrimDef::OptionUnwrap.id()), ], ancestors: vec![TypeAnnotation::CustomClass { - id: PRIMITIVE_DEF_IDS.option, + id: PrimDef::Option.id(), params: Vec::default(), }], constructor: None, @@ -517,12 +517,12 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built })), Arc::new(RwLock::new(TopLevelDef::Class { name: "ndarray".into(), - object_id: PRIMITIVE_DEF_IDS.ndarray, + object_id: PrimDef::NDArray.id(), type_vars: vec![ndarray_dtype_ty, ndarray_ndims_ty], fields: Vec::default(), methods: vec![ - ("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)), - ("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)), + ("copy".into(), ndarray_copy_ty.0, PrimDef::NDArrayCopy.id()), + ("fill".into(), ndarray_fill_ty.0, PrimDef::NDArrayFill.id()), ], ancestors: Vec::default(), constructor: None, @@ -1317,7 +1317,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built } } TypeEnum::TObj { obj_id, .. } - if *obj_id == PRIMITIVE_DEF_IDS.ndarray => + if *obj_id == PrimDef::NDArray.id() => { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index f1f72fe8c..172136376 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -4,76 +4,121 @@ use crate::symbol_resolver::SymbolValue; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{Mapping, VarMap}; use nac3parser::ast::{Constant, Location}; +use strum::IntoEnumIterator; +use strum_macros::EnumIter; use super::*; -/// Structure storing [`DefinitionId`] for primitive types. -#[derive(Clone, Copy)] -pub struct PrimitiveDefinitionIds { - pub int32: DefinitionId, - pub int64: DefinitionId, - pub uint32: DefinitionId, - pub uint64: DefinitionId, - pub float: DefinitionId, - pub bool: DefinitionId, - pub none: DefinitionId, - pub range: DefinitionId, - pub str: DefinitionId, - pub exception: DefinitionId, - pub option: DefinitionId, - pub ndarray: DefinitionId, +/// All primitive types and functions in nac3core. +#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)] +pub enum PrimDef { + Int32, + Int64, + Float, + Bool, + None, + Range, + Str, + Exception, + UInt32, + UInt64, + Option, + OptionIsSome, + OptionIsNone, + OptionUnwrap, + NDArray, + NDArrayCopy, + NDArrayFill, + FunInt32, + FunInt64, + FunUInt32, + FunUInt64, + FunFloat, + FunNpNDArray, + FunNpEmpty, + FunNpZeros, + FunNpOnes, + FunNpFull, + FunNpArray, + FunNpEye, + FunNpIdentity, + FunRound, + FunRound64, + FunNpRound, + FunRange, + FunStr, + FunBool, + FunFloor, + FunFloor64, + FunNpFloor, + FunCeil, + FunCeil64, + FunNpCeil, + FunLen, + FunMin, + FunNpMin, + FunNpMinimum, + FunMax, + FunNpMax, + FunNpMaximum, + FunAbs, + FunNpIsNan, + FunNpIsInf, + FunNpSin, + FunNpCos, + FunNpExp, + FunNpExp2, + FunNpLog, + FunNpLog10, + FunNpLog2, + FunNpFabs, + FunNpSqrt, + FunNpRint, + FunNpTan, + FunNpArcsin, + FunNpArccos, + FunNpArctan, + FunNpSinh, + FunNpCosh, + FunNpTanh, + FunNpArcsinh, + FunNpArccosh, + FunNpArctanh, + FunNpExpm1, + FunNpCbrt, + FunSpSpecErf, + FunSpSpecErfc, + FunSpSpecGamma, + FunSpSpecGammaln, + FunSpSpecJ0, + FunSpSpecJ1, + FunNpArctan2, + FunNpCopysign, + FunNpFmax, + FunNpFmin, + FunNpLdExp, + FunNpHypot, + FunNpNextAfter, + FunSome, } -impl PrimitiveDefinitionIds { - /// Returns all [`DefinitionId`] of primitives as a [`Vec`]. +impl PrimDef { + /// Get the assigned [`DefinitionId`] of this [`PrimDef`]. /// - /// There are no guarantees on ordering of the IDs. + /// The assigned definition ID is defined by the position this [`PrimDef`] enum unit variant is defined at, + /// with the first `PrimDef`'s definition id being `0`. #[must_use] - fn as_vec(&self) -> Vec { - vec![ - self.int32, - self.int64, - self.uint32, - self.uint64, - self.float, - self.bool, - self.none, - self.range, - self.str, - self.exception, - self.option, - self.ndarray, - ] + pub fn id(&self) -> DefinitionId { + DefinitionId(*self as usize) } - /// Returns an iterator over all [`DefinitionId`]s of this instance in indeterminate order. - pub fn iter(&self) -> impl Iterator { - self.as_vec().into_iter() - } - - /// Returns the primitive with the largest [`DefinitionId`]. + /// Check if a definition ID is that of a [`PrimDef`]. #[must_use] - pub fn max_id(&self) -> DefinitionId { - self.iter().max().unwrap() + pub fn contains_id(id: DefinitionId) -> bool { + Self::iter().any(|prim| prim.id() == id) } } -/// The [definition IDs][DefinitionId] for primitive types. -pub const PRIMITIVE_DEF_IDS: PrimitiveDefinitionIds = PrimitiveDefinitionIds { - int32: DefinitionId(0), - int64: DefinitionId(1), - uint32: DefinitionId(8), - uint64: DefinitionId(9), - float: DefinitionId(2), - bool: DefinitionId(3), - none: DefinitionId(4), - range: DefinitionId(5), - str: DefinitionId(6), - exception: DefinitionId(7), - option: DefinitionId(10), - ndarray: DefinitionId(14), -}; - impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { @@ -116,42 +161,42 @@ impl TopLevelComposer { pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) { let mut unifier = Unifier::new(); let int32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.int32, + obj_id: PrimDef::Int32.id(), fields: HashMap::new(), params: VarMap::new(), }); let int64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.int64, + obj_id: PrimDef::Int64.id(), fields: HashMap::new(), params: VarMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.float, + obj_id: PrimDef::Float.id(), fields: HashMap::new(), params: VarMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.bool, + obj_id: PrimDef::Bool.id(), fields: HashMap::new(), params: VarMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.none, + obj_id: PrimDef::None.id(), fields: HashMap::new(), params: VarMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.range, + obj_id: PrimDef::Range.id(), fields: HashMap::new(), params: VarMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.str, + obj_id: PrimDef::Str.id(), fields: HashMap::new(), params: VarMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.exception, + obj_id: PrimDef::Exception.id(), fields: vec![ ("__name__".into(), (int32, true)), ("__file__".into(), (str, true)), @@ -168,12 +213,12 @@ impl TopLevelComposer { params: VarMap::new(), }); let uint32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.uint32, + obj_id: PrimDef::UInt32.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.uint64, + obj_id: PrimDef::UInt64.id(), fields: HashMap::new(), params: VarMap::new(), }); @@ -190,7 +235,7 @@ impl TopLevelComposer { vars: VarMap::from([(option_type_var.1, option_type_var.0)]), })); let option = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.option, + obj_id: PrimDef::Option.id(), fields: vec![ ("is_some".into(), (is_some_type_fun_ty, true)), ("is_none".into(), (is_some_type_fun_ty, true)), @@ -232,7 +277,7 @@ impl TopLevelComposer { ]), })); let ndarray = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.ndarray, + obj_id: PrimDef::NDArray.id(), fields: Mapping::from([ ("copy".into(), (ndarray_copy_fun_ty, true)), ("fill".into(), (ndarray_fill_fun_ty, true)), @@ -689,7 +734,7 @@ pub fn parse_parameter_default_value( /// Obtains the element type of an array-like type. pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { match &*unifier.get_ty(ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { unpack_ndarray_var_tys(unifier, ty).0 } @@ -701,7 +746,7 @@ pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { /// Obtains the number of dimensions of an array-like type. pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { match &*unifier.get_ty(ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let ndims = unpack_ndarray_var_tys(unifier, ty).1; let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else { panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims)) diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 99129a9a3..b6e0ca557 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -1,5 +1,5 @@ use crate::{ - toplevel::helper::PRIMITIVE_DEF_IDS, + toplevel::helper::PrimDef, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, TypeEnum, Unifier, VarMap}, @@ -37,7 +37,7 @@ pub fn subst_ndarray_tvars( let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; - debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); + debug_assert_eq!(*obj_id, PrimDef::NDArray.id()); if dtype.is_none() && ndims.is_none() { return ndarray; @@ -61,7 +61,7 @@ fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(u32, Type) let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; - debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); + debug_assert_eq!(*obj_id, PrimDef::NDArray.id()); debug_assert_eq!(params.len(), 2); params diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 4f42d8d29..582410751 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,6 +1,6 @@ use super::*; use crate::symbol_resolver::SymbolValue; -use crate::toplevel::helper::PRIMITIVE_DEF_IDS; +use crate::toplevel::helper::PrimDef; use crate::typecheck::typedef::VarMap; use nac3parser::ast::Constant; @@ -95,10 +95,7 @@ pub fn parse_ast_to_type_annotation_kinds( } else if id == &"str".into() { Ok(TypeAnnotation::Primitive(primitives.str)) } else if id == &"Exception".into() { - Ok(TypeAnnotation::CustomClass { - id: PRIMITIVE_DEF_IDS.exception, - params: Vec::default(), - }) + Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() }) } else if let Ok(obj_id) = resolver.get_identifier_def(*id) { let type_vars = { let def_read = top_level_defs[obj_id.0].try_read(); diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 6f1eeced1..ed27d98b6 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,5 +1,5 @@ use crate::symbol_resolver::SymbolValue; -use crate::toplevel::helper::PRIMITIVE_DEF_IDS; +use crate::toplevel::helper::PrimDef; use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::typecheck::{ type_inferencer::*, @@ -11,6 +11,7 @@ use nac3parser::ast::{Cmpop, Operator, Unaryop}; use std::cmp::max; use std::collections::HashMap; use std::rc::Rc; +use strum::IntoEnumIterator; #[must_use] pub fn binop_name(op: &Operator) -> &'static str { @@ -360,8 +361,8 @@ pub fn typeof_ndarray_broadcast( left: Type, right: Type, ) -> Result { - let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); assert!(is_left_ndarray || is_right_ndarray); @@ -428,8 +429,8 @@ pub fn typeof_binop( lhs: Type, rhs: Type, ) -> Result, String> { - let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); Ok(Some(match op { Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => { @@ -534,17 +535,15 @@ pub fn typeof_unaryop( Ok(match *op { Unaryop::Not => match operand_obj_id { - Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand), + Some(v) if v == PrimDef::NDArray.id() => Some(operand), Some(_) => Some(primitives.bool), _ => None, }, Unaryop::Invert => { - if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { + if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) { Some(primitives.int32) - } else if operand_obj_id - .is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) - { + } else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) { Some(operand) } else { None @@ -552,9 +551,9 @@ pub fn typeof_unaryop( } Unaryop::UAdd | Unaryop::USub => { - if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) { let (dtype, _) = unpack_ndarray_var_tys(unifier, operand); - if dtype.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { + if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) { return Err(if *op == Unaryop::UAdd { "The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string() } else { @@ -563,11 +562,9 @@ pub fn typeof_unaryop( } Some(operand) - } else if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { + } else if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) { Some(primitives.int32) - } else if operand_obj_id - .is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) - { + } else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) { Some(operand) } else { None @@ -584,8 +581,8 @@ pub fn typeof_cmpop( lhs: Type, rhs: Type, ) -> Result, String> { - let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); - let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); Ok(Some(if is_left_ndarray || is_right_ndarray { let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?; diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 0297eaeeb..4c8006741 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -9,7 +9,7 @@ use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ - helper::{arraylike_flatten_element_type, arraylike_get_ndims, PRIMITIVE_DEF_IDS}, + helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, }, @@ -244,7 +244,7 @@ impl<'a> Fold<()> for Inferencer<'a> { TypeEnum::TList { .. } => { self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { todo!() } _ => unreachable!(), @@ -892,9 +892,7 @@ impl<'a> Inferencer<'a> { let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); - let ret = if arg0_ty - .obj_id(self.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); @@ -932,14 +930,14 @@ impl<'a> Inferencer<'a> { let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); - let ret = - if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { - let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); + let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) + { + let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); - ndarray_dtype - } else { - arg0_ty - }; + ndarray_dtype + } else { + arg0_ty + }; let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { @@ -987,14 +985,14 @@ impl<'a> Inferencer<'a> { let arg1_ty = arg1.custom.unwrap(); let arg0_dtype = - if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { unpack_ndarray_var_tys(self.unifier, arg0_ty).0 } else { arg0_ty }; let arg1_dtype = - if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { unpack_ndarray_var_tys(self.unifier, arg1_ty).0 } else { arg1_ty @@ -1020,15 +1018,12 @@ impl<'a> Inferencer<'a> { }; let ret = if [&arg0_ty, &arg1_ty].into_iter().any(|arg_ty| { - arg_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + arg_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) }) { // typeof_ndarray_broadcast requires both dtypes to be the same, but ldexp accepts // (float, int32), so convert it to align with the dtype of the first arg let arg1_ty = if id == &"np_ldexp".into() { - if arg1_ty - .obj_id(self.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) - { + if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty); make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims)) @@ -1116,9 +1111,7 @@ impl<'a> Inferencer<'a> { let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); - let ret = if arg0_ty - .obj_id(self.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); @@ -1552,7 +1545,7 @@ impl<'a> Inferencer<'a> { expr.custom .unwrap() .obj_id(self.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + .is_some_and(|id| id == PrimDef::NDArray.id()) }) { return Err(HashSet::from([String::from( @@ -1670,7 +1663,7 @@ impl<'a> Inferencer<'a> { } let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); @@ -1684,7 +1677,7 @@ impl<'a> Inferencer<'a> { } ExprKind::Constant { value: ast::Constant::Int(val), .. } => { match &*self.unifier.get_ty(value.custom.unwrap()) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); self.infer_subscript_ndarray(value, ty, ndims) @@ -1710,7 +1703,7 @@ impl<'a> Inferencer<'a> { .custom .unwrap() .obj_id(self.unifier) - .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + .is_some_and(|id| id == PrimDef::NDArray.id()) .not() { return report_error( @@ -1755,7 +1748,7 @@ impl<'a> Inferencer<'a> { self.constrain(value.custom.unwrap(), list, &value.location)?; Ok(ty) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 13684e769..f9fec50a0 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -3,7 +3,7 @@ use super::*; use crate::{ codegen::CodeGenContext, symbol_resolver::ValueEnum, - toplevel::{helper::PRIMITIVE_DEF_IDS, DefinitionId, TopLevelDef}, + toplevel::{helper::PrimDef, DefinitionId, TopLevelDef}, }; use indoc::indoc; use nac3parser::parser::parse_program; @@ -75,7 +75,7 @@ impl TestEnvironment { let mut unifier = Unifier::new(); let int32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.int32, + obj_id: PrimDef::Int32.id(), fields: HashMap::new(), params: VarMap::new(), }); @@ -88,52 +88,52 @@ impl TestEnvironment { fields.insert("__add__".into(), (add_ty, false)); }); let int64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.int64, + obj_id: PrimDef::Int64.id(), fields: HashMap::new(), params: VarMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.float, + obj_id: PrimDef::Float.id(), fields: HashMap::new(), params: VarMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.bool, + obj_id: PrimDef::Bool.id(), fields: HashMap::new(), params: VarMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.none, + obj_id: PrimDef::None.id(), fields: HashMap::new(), params: VarMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.range, + obj_id: PrimDef::Range.id(), fields: HashMap::new(), params: VarMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.str, + obj_id: PrimDef::Str.id(), fields: HashMap::new(), params: VarMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.exception, + obj_id: PrimDef::Exception.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.uint32, + obj_id: PrimDef::UInt32.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.uint64, + obj_id: PrimDef::UInt64.id(), fields: HashMap::new(), params: VarMap::new(), }); let option = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.option, + obj_id: PrimDef::Option.id(), fields: HashMap::new(), params: VarMap::new(), }); @@ -141,7 +141,7 @@ impl TestEnvironment { let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None); let ndarray = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.ndarray, + obj_id: PrimDef::NDArray.id(), fields: HashMap::new(), params: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), @@ -214,7 +214,7 @@ impl TestEnvironment { let mut identifier_mapping = HashMap::new(); let mut top_level_defs: Vec>> = Vec::new(); let int32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.int32, + obj_id: PrimDef::Int32.id(), fields: HashMap::new(), params: VarMap::new(), }); @@ -227,57 +227,57 @@ impl TestEnvironment { fields.insert("__add__".into(), (add_ty, false)); }); let int64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.int64, + obj_id: PrimDef::Int64.id(), fields: HashMap::new(), params: VarMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.float, + obj_id: PrimDef::Float.id(), fields: HashMap::new(), params: VarMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.bool, + obj_id: PrimDef::Bool.id(), fields: HashMap::new(), params: VarMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.none, + obj_id: PrimDef::None.id(), fields: HashMap::new(), params: VarMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.range, + obj_id: PrimDef::Range.id(), fields: HashMap::new(), params: VarMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.str, + obj_id: PrimDef::Str.id(), fields: HashMap::new(), params: VarMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.exception, + obj_id: PrimDef::Exception.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.uint32, + obj_id: PrimDef::UInt32.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.uint64, + obj_id: PrimDef::UInt64.id(), fields: HashMap::new(), params: VarMap::new(), }); let option = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.option, + obj_id: PrimDef::Option.id(), fields: HashMap::new(), params: VarMap::new(), }); let ndarray = unifier.add_ty(TypeEnum::TObj { - obj_id: PRIMITIVE_DEF_IDS.ndarray, + obj_id: PrimDef::NDArray.id(), fields: HashMap::new(), params: VarMap::new(), });