From 13f06f3e290ebdc509c8ea0823cc6d42146088a2 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 14 Mar 2024 13:21:56 +0800 Subject: [PATCH] core: Refactor VarMap to IndexMap This is the only Map I can find that preserves insertion order while also deduplicating elements by key. --- Cargo.lock | 1 + nac3core/Cargo.toml | 1 + nac3core/src/codegen/concrete_type.rs | 3 +- nac3core/src/codegen/expr.rs | 3 +- nac3core/src/codegen/test.rs | 3 +- nac3core/src/toplevel/composer.rs | 5 ++-- nac3core/src/toplevel/numpy.rs | 1 - ...est__test_analyze__list_tuple_generic.snap | 2 +- ...__toplevel__test__test_analyze__self1.snap | 10 +++---- nac3core/src/toplevel/type_annotation.rs | 1 + nac3core/src/typecheck/typedef/mod.rs | 29 ++++++++----------- nac3core/src/typecheck/typedef/test.rs | 6 ++-- 12 files changed, 30 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9de4130f..0b6ac393 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -616,6 +616,7 @@ name = "nac3core" version = "0.1.0" dependencies = [ "crossbeam", + "indexmap 2.2.5", "indoc", "inkwell", "insta", diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index d964e5a0..dc81b33a 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] itertools = "0.12" crossbeam = "0.8" +indexmap = "2.2" parking_lot = "0.12" rayon = "1.8" nac3parser = { path = "../nac3parser" } diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 30de4400..87ced722 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -9,6 +9,7 @@ use crate::{ use nac3parser::ast::StrRef; use std::collections::HashMap; +use indexmap::IndexMap; pub struct ConcreteTypeStore { store: Vec, @@ -50,7 +51,7 @@ pub enum ConcreteTypeEnum { TObj { obj_id: DefinitionId, fields: HashMap, - params: HashMap, + params: IndexMap, }, TVirtual { ty: ConcreteType, diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 703fe164..c110d0c6 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -59,7 +59,7 @@ pub fn get_subst_key( params.clone() }) .unwrap_or_default(); - vars.extend(fun_vars.iter()); + vars.extend(fun_vars); let sorted = vars.keys().filter(|id| filter.map_or(true, |v| v.contains(id))).sorted(); sorted .map(|id| { @@ -1983,7 +1983,6 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } TypeEnum::TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { let (ty, ndims) = params.iter() - .sorted_by_key(|(var_id, _)| *var_id) .map(|(_, ty)| ty) .collect_tuple() .unwrap(); diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 159ac2b4..1810a18b 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -10,7 +10,7 @@ use crate::{ }, typecheck::{ type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, + typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }, }; use indoc::indoc; @@ -25,7 +25,6 @@ use nac3parser::{ use parking_lot::RwLock; use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use crate::typecheck::typedef::VarMap; struct Resolver { id_to_type: HashMap, diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index c51edaeb..e1c800b9 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1926,9 +1926,8 @@ impl TopLevelComposer { ret_str, name, ast.as_ref().unwrap().location - ), - ])) - } + ),])) + } instance_to_stmt.insert( get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())), diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index f9ac2eb1..d3225192 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -28,7 +28,6 @@ pub fn make_ndarray_ty( let tvar_ids = params.iter() .map(|(obj_id, _)| *obj_id) - .sorted() .collect_vec(); debug_assert_eq!(tvar_ids.len(), 2); diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index c468b5f3..7d247935 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -7,7 +7,7 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n", - "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 07fa75e3..e32ae809 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,11 +3,11 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", - "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\",\nvar_id: []\n}\n", + "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index ddcf9469..e6769a03 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,5 +1,6 @@ use crate::symbol_resolver::SymbolValue; use crate::toplevel::helper::PRIMITIVE_DEF_IDS; +use crate::typecheck::typedef::VarMap; use super::*; use nac3parser::ast::Constant; diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index f9eb73f4..c23b3d84 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -1,10 +1,11 @@ use std::cell::RefCell; -use std::collections::{BTreeMap, HashMap}; +use std::collections::HashMap; use std::fmt::Display; use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::{borrow::Cow, collections::HashSet}; use std::iter::zip; +use indexmap::IndexMap; use itertools::Itertools; use nac3parser::ast::{Location, StrRef}; @@ -25,14 +26,10 @@ pub type Type = UnificationKey; pub struct CallId(pub(super) usize); pub type Mapping = HashMap; +pub type IndexMapping = IndexMap; -/// A [`Mapping`] sorted by its key. -/// -/// This type is recommended for mappings that should be stored and/or iterated by its sorted key. -pub type SortedMapping = BTreeMap; - -/// A [`BTreeMap`] storing the mapping between type variable ID and [unifier type][`Type`]. -pub type VarMap = SortedMapping; +/// The mapping between type variable ID and [unifier type][`Type`]. +pub type VarMap = IndexMapping; #[derive(Clone)] pub struct Call { @@ -920,8 +917,8 @@ impl Unifier { // Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits // all K-V pairs "in arbitrary order" let (tv1, tv2) = ( - params1.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(), - params2.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(), + params1.iter().map(|(_, v)| v).collect_vec(), + params2.iter().map(|(_, v)| v).collect_vec(), ); for (x, y) in zip(tv1, tv2) { if self.unify_impl(*x, *y, false).is_err() { @@ -1097,11 +1094,9 @@ impl Unifier { if params.is_empty() { name } else { - let params = params + let mut params = params .iter() .map(|(_, v)| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); - // sort to preserve order - let mut params = params.sorted(); format!("{}[{}]", name, params.join(", ")) } } @@ -1283,12 +1278,12 @@ impl Unifier { fn subst_map( &mut self, - map: &SortedMapping, + map: &IndexMapping, mapping: &VarMap, cache: &mut HashMap>, - ) -> Option> - where - K: Ord + Eq + Clone, + ) -> Option> + where + K: std::hash::Hash + Eq + Clone, { let mut map2 = None; for (k, v) in map { diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 24eece9f..fd15cfb2 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -45,9 +45,9 @@ impl Unifier { } } - fn map_eq(&mut self, map1: &SortedMapping, map2: &SortedMapping) -> bool - where - K: Ord + Eq + Clone, + fn map_eq(&mut self, map1: &IndexMapping, map2: &IndexMapping) -> bool + where + K: std::hash::Hash + Eq + Clone { if map1.len() != map2.len() { return false;