forked from M-Labs/nac3
1
0
Fork 0

core: Refactor VarMap to IndexMap

This is the only Map I can find that preserves insertion order while
also deduplicating elements by key.
This commit is contained in:
David Mak 2024-03-14 13:21:56 +08:00
parent f0da9c0283
commit 13f06f3e29
12 changed files with 30 additions and 35 deletions

1
Cargo.lock generated
View File

@ -616,6 +616,7 @@ name = "nac3core"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"crossbeam", "crossbeam",
"indexmap 2.2.5",
"indoc", "indoc",
"inkwell", "inkwell",
"insta", "insta",

View File

@ -7,6 +7,7 @@ edition = "2021"
[dependencies] [dependencies]
itertools = "0.12" itertools = "0.12"
crossbeam = "0.8" crossbeam = "0.8"
indexmap = "2.2"
parking_lot = "0.12" parking_lot = "0.12"
rayon = "1.8" rayon = "1.8"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }

View File

@ -9,6 +9,7 @@ use crate::{
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
use std::collections::HashMap; use std::collections::HashMap;
use indexmap::IndexMap;
pub struct ConcreteTypeStore { pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>, store: Vec<ConcreteTypeEnum>,
@ -50,7 +51,7 @@ pub enum ConcreteTypeEnum {
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
fields: HashMap<StrRef, (ConcreteType, bool)>, fields: HashMap<StrRef, (ConcreteType, bool)>,
params: HashMap<u32, ConcreteType>, params: IndexMap<u32, ConcreteType>,
}, },
TVirtual { TVirtual {
ty: ConcreteType, ty: ConcreteType,

View File

@ -59,7 +59,7 @@ pub fn get_subst_key(
params.clone() params.clone()
}) })
.unwrap_or_default(); .unwrap_or_default();
vars.extend(fun_vars.iter()); vars.extend(fun_vars);
let sorted = vars.keys().filter(|id| filter.map_or(true, |v| v.contains(id))).sorted(); let sorted = vars.keys().filter(|id| filter.map_or(true, |v| v.contains(id))).sorted();
sorted sorted
.map(|id| { .map(|id| {
@ -1983,7 +1983,6 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TypeEnum::TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (ty, ndims) = params.iter() let (ty, ndims) = params.iter()
.sorted_by_key(|(var_id, _)| *var_id)
.map(|(_, ty)| ty) .map(|(_, ty)| ty)
.collect_tuple() .collect_tuple()
.unwrap(); .unwrap();

View File

@ -10,7 +10,7 @@ use crate::{
}, },
typecheck::{ typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
}, },
}; };
use indoc::indoc; use indoc::indoc;
@ -25,7 +25,6 @@ use nac3parser::{
use parking_lot::RwLock; use parking_lot::RwLock;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use crate::typecheck::typedef::VarMap;
struct Resolver { struct Resolver {
id_to_type: HashMap<StrRef, Type>, id_to_type: HashMap<StrRef, Type>,

View File

@ -1926,9 +1926,8 @@ impl TopLevelComposer {
ret_str, ret_str,
name, name,
ast.as_ref().unwrap().location ast.as_ref().unwrap().location
), ),]))
])) }
}
instance_to_stmt.insert( instance_to_stmt.insert(
get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())), get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())),

View File

@ -28,7 +28,6 @@ pub fn make_ndarray_ty(
let tvar_ids = params.iter() let tvar_ids = params.iter()
.map(|(obj_id, _)| *obj_id) .map(|(obj_id, _)| *obj_id)
.sorted()
.collect_vec(); .collect_vec();
debug_assert_eq!(tvar_ids.len(), 2); debug_assert_eq!(tvar_ids.len(), 2);

View File

@ -7,7 +7,7 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
] ]

View File

@ -3,11 +3,11 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n",
] ]

View File

@ -1,5 +1,6 @@
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS; use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::typecheck::typedef::VarMap;
use super::*; use super::*;
use nac3parser::ast::Constant; use nac3parser::ast::Constant;

View File

@ -1,10 +1,11 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap}; use std::collections::HashMap;
use std::fmt::Display; use std::fmt::Display;
use std::rc::Rc; use std::rc::Rc;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet}; use std::{borrow::Cow, collections::HashSet};
use std::iter::zip; use std::iter::zip;
use indexmap::IndexMap;
use itertools::Itertools; use itertools::Itertools;
use nac3parser::ast::{Location, StrRef}; use nac3parser::ast::{Location, StrRef};
@ -25,14 +26,10 @@ pub type Type = UnificationKey;
pub struct CallId(pub(super) usize); pub struct CallId(pub(super) usize);
pub type Mapping<K, V = Type> = HashMap<K, V>; pub type Mapping<K, V = Type> = HashMap<K, V>;
pub type IndexMapping<K, V = Type> = IndexMap<K, V>;
/// A [`Mapping`] sorted by its key. /// The mapping between type variable ID and [unifier type][`Type`].
/// pub type VarMap = IndexMapping<u32>;
/// This type is recommended for mappings that should be stored and/or iterated by its sorted key.
pub type SortedMapping<K, V = Type> = BTreeMap<K, V>;
/// A [`BTreeMap`] storing the mapping between type variable ID and [unifier type][`Type`].
pub type VarMap = SortedMapping<u32>;
#[derive(Clone)] #[derive(Clone)]
pub struct Call { pub struct Call {
@ -920,8 +917,8 @@ impl Unifier {
// Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits // Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits
// all K-V pairs "in arbitrary order" // all K-V pairs "in arbitrary order"
let (tv1, tv2) = ( let (tv1, tv2) = (
params1.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(), params1.iter().map(|(_, v)| v).collect_vec(),
params2.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(), params2.iter().map(|(_, v)| v).collect_vec(),
); );
for (x, y) in zip(tv1, tv2) { for (x, y) in zip(tv1, tv2) {
if self.unify_impl(*x, *y, false).is_err() { if self.unify_impl(*x, *y, false).is_err() {
@ -1097,11 +1094,9 @@ impl Unifier {
if params.is_empty() { if params.is_empty() {
name name
} else { } else {
let params = params let mut params = params
.iter() .iter()
.map(|(_, v)| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); .map(|(_, v)| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
// sort to preserve order
let mut params = params.sorted();
format!("{}[{}]", name, params.join(", ")) format!("{}[{}]", name, params.join(", "))
} }
} }
@ -1283,12 +1278,12 @@ impl Unifier {
fn subst_map<K>( fn subst_map<K>(
&mut self, &mut self,
map: &SortedMapping<K>, map: &IndexMapping<K>,
mapping: &VarMap, mapping: &VarMap,
cache: &mut HashMap<Type, Option<Type>>, cache: &mut HashMap<Type, Option<Type>>,
) -> Option<SortedMapping<K>> ) -> Option<IndexMapping<K>>
where where
K: Ord + Eq + Clone, K: std::hash::Hash + Eq + Clone,
{ {
let mut map2 = None; let mut map2 = None;
for (k, v) in map { for (k, v) in map {

View File

@ -45,9 +45,9 @@ impl Unifier {
} }
} }
fn map_eq<K>(&mut self, map1: &SortedMapping<K>, map2: &SortedMapping<K>) -> bool fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool
where where
K: Ord + Eq + Clone, K: std::hash::Hash + Eq + Clone
{ {
if map1.len() != map2.len() { if map1.len() != map2.len() {
return false; return false;