diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index ccba682..8371dc7 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -10,7 +10,7 @@ use nac3core::{ }, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{Type, TypeEnum, Unifier, VarMap}, + typedef::{SortedVarMap, Type, TypeEnum, Unifier}, }, }; use nac3parser::ast::{self, StrRef}; @@ -519,7 +519,7 @@ impl InnerResolver { .iter() .zip(args.iter()) .map(|((id, _), ty)| (*id, *ty)) - .collect::() + .collect::() }; Ok(Ok((unifier.subst(origin_ty, &subst).unwrap_or(origin_ty), true))) } @@ -722,7 +722,7 @@ impl InnerResolver { assert_eq!(*id, *id_var); (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) }) - .collect::(); + .collect::(); return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())) } @@ -734,7 +734,7 @@ impl InnerResolver { ))) } }; - let new_var_map: VarMap = params.iter().map(|(id, _)| (*id, ty)).collect(); + let new_var_map: SortedVarMap = params.iter().map(|(id, _)| (*id, ty)).collect(); let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty); Ok(Ok(res)) } @@ -751,7 +751,7 @@ impl InnerResolver { assert_eq!(*id, *id_var); (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) }) - .collect::(); + .collect::(); let mut instantiate_obj = || { // loop through non-function fields of the class to get the instantiated value for field in fields { diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 30de440..665638b 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -50,7 +50,7 @@ pub enum ConcreteTypeEnum { TObj { obj_id: DefinitionId, fields: HashMap, - params: HashMap, + params: Vec<(u32, ConcreteType)>, }, TVirtual { ty: ConcreteType, diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 21001c0..c9eee00 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -20,7 +20,7 @@ use crate::{ TopLevelDef, }, typecheck::{ - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, + typedef::{FunSignature, FuncArg, SortedVarMap, Type, TypeEnum, Unifier}, magic_methods::{binop_name, binop_assign_name}, }, }; @@ -42,7 +42,7 @@ use super::{CodeGenerator, llvm_intrinsics::call_memcpy_generic, need_sret}; pub fn get_subst_key( unifier: &mut Unifier, obj: Option, - fun_vars: &VarMap, + fun_vars: &SortedVarMap, filter: Option<&Vec>, ) -> String { let mut vars = obj @@ -50,10 +50,10 @@ pub fn get_subst_key( let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; - params.clone() + params.to_sorted() }) .unwrap_or_default(); - vars.extend(fun_vars.iter()); + vars.extend(fun_vars); let sorted = vars.keys().filter(|id| filter.map_or(true, |v| v.contains(id))).sorted(); sorted .map(|id| { @@ -86,7 +86,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { fun: &FunSignature, filter: Option<&Vec>, ) -> String { - get_subst_key(&mut self.unifier, obj, &fun.vars, filter) + get_subst_key(&mut self.unifier, obj, &fun.vars.to_sorted(), filter) } pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize { @@ -644,7 +644,7 @@ pub fn gen_func_instance<'ctx>( let mut filter = var_id.clone(); if let Some((obj_ty, _)) = &obj { if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty(*obj_ty) { - filter.extend(params.keys()); + filter.extend(params.ids()); } } let key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), sign, Some(&filter)); @@ -1977,7 +1977,6 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } TypeEnum::TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { let (ty, ndims) = params.iter() - .sorted_by_key(|(var_id, _)| *var_id) .map(|(_, ty)| ty) .collect_tuple() .unwrap(); diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 159ac2b..1810a18 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -10,7 +10,7 @@ use crate::{ }, typecheck::{ type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, + typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }, }; use indoc::indoc; @@ -25,7 +25,6 @@ use nac3parser::{ use parking_lot::RwLock; use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use crate::typecheck::typedef::VarMap; struct Resolver { id_to_type: HashMap, diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index de41ac1..35731d0 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -527,12 +527,12 @@ pub fn parse_type_annotation( let mut fields = fields .iter() .map(|(attr, ty, is_mutable)| { - let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + let ty = unifier.subst(*ty, &subst.to_sorted()).unwrap_or(*ty); (*attr, (ty, *is_mutable)) }) .collect::>(); fields.extend(methods.iter().map(|(attr, ty, _)| { - let ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + let ty = unifier.subst(*ty, &subst.to_sorted()).unwrap_or(*ty); (*attr, (ty, false)) })); Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst })) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index c51edae..89f6950 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -6,7 +6,7 @@ use crate::{ symbol_resolver::SymbolValue, typecheck::{ type_inferencer::{FunctionData, Inferencer}, - typedef::VarMap, + typedef::{SortedVarMap, VarMap}, }, }; @@ -779,7 +779,7 @@ impl TopLevelComposer { let mut new_fields = HashMap::new(); let mut need_subst = false; for (name, (ty, mutable)) in fields { - let substituted = unifier.subst(*ty, params); + let substituted = unifier.subst(*ty, ¶ms.to_sorted()); need_subst |= substituted.is_some(); new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable)); } @@ -1508,7 +1508,7 @@ impl TopLevelComposer { } = &mut *def.write() { if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = unifier.get_ty(*signature).as_ref() { - let new_var_ids = vars.values().map(|v| match &*unifier.get_ty(*v) { + let new_var_ids = vars.types().map(|v| match &*unifier.get_ty(*v) { TypeEnum::TVar{id, ..} => *id, _ => unreachable!(), }).collect_vec(); @@ -1516,7 +1516,7 @@ impl TopLevelComposer { let new_signature = FunSignature { args: args.clone(), ret: *ret, - vars: new_var_ids.iter().zip(vars.values()).map(|(id, v)| (*id, *v)).collect(), + vars: new_var_ids.iter().zip(vars.types()).map(|(id, v)| (*id, *v)).collect(), }; unifier.unification_table.set_value(*signature, Rc::new(TypeEnum::TFunc(new_signature))); *var_id = new_var_ids; @@ -1612,7 +1612,7 @@ impl TopLevelComposer { }; constructor_args.extend_from_slice(args); - type_vars.extend(vars); + type_vars.extend(vars.to_sorted()); } } (constructor_args, type_vars) @@ -1738,7 +1738,7 @@ impl TopLevelComposer { let (type_var_subst_comb, no_range_vars) = { let mut no_ranges: Vec = Vec::new(); let var_combs = vars - .values() + .types() .map(|ty| { unifier.get_instantiations(*ty).unwrap_or_else(|| { let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty) else { @@ -1752,13 +1752,13 @@ impl TopLevelComposer { }) .multi_cartesian_product() .collect_vec(); - let mut result: Vec = Vec::default(); + let mut result: Vec = Vec::default(); for comb in var_combs { - result.push(vars.keys().copied().zip(comb).collect()); + result.push(vars.ids().copied().zip(comb).collect()); } // NOTE: if is empty, means no type var, append a empty subst, ok to do this? if result.is_empty() { - result.push(VarMap::new()); + result.push(SortedVarMap::new()); } (result, no_ranges) }; @@ -1798,7 +1798,7 @@ impl TopLevelComposer { None } }) - .collect::() + .collect::() }; unifier.subst(self_type, &subst_for_self).unwrap_or(self_type) }) @@ -1926,12 +1926,11 @@ impl TopLevelComposer { ret_str, name, ast.as_ref().unwrap().location - ), - ])) - } + ),])) + } instance_to_stmt.insert( - get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())), + get_subst_key(unifier, self_type, &subst, Some(&vars.ids().copied().collect())), FunInstance { body: Arc::new(fun_body), unifier_id: 0, diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 1209f51..73d8b09 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -8,7 +8,7 @@ use std::{ use super::codegen::CodeGenContext; use super::typecheck::type_inferencer::PrimitiveStore; -use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap}; +use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, SortedVarMap}; use crate::{ codegen::CodeGenerator, symbol_resolver::{SymbolResolver, ValueEnum}, @@ -76,7 +76,7 @@ impl Debug for GenCall { pub struct FunInstance { pub body: Arc>>>, pub calls: Arc>, - pub subst: VarMap, + pub subst: SortedVarMap, pub unifier_id: usize, } diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index f9ac2eb..1ec10cf 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -3,7 +3,7 @@ use crate::{ toplevel::helper::PRIMITIVE_DEF_IDS, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{Type, TypeEnum, Unifier, VarMap}, + typedef::{Type, TypeEnum, Unifier, SortedVarMap}, }, }; @@ -28,11 +28,10 @@ pub fn make_ndarray_ty( let tvar_ids = params.iter() .map(|(obj_id, _)| *obj_id) - .sorted() .collect_vec(); debug_assert_eq!(tvar_ids.len(), 2); - let mut tvar_subst = VarMap::new(); + let mut tvar_subst = SortedVarMap::new(); if let Some(dtype) = dtype { tvar_subst.insert(tvar_ids[0], dtype); } diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index c468b5f..7d24793 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -7,7 +7,7 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n", - "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 07fa75e..e32ae80 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,11 +3,11 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", - "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\",\nvar_id: []\n}\n", + "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index ddcf946..41ae68b 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,5 +1,6 @@ use crate::symbol_resolver::SymbolValue; use crate::toplevel::helper::PRIMITIVE_DEF_IDS; +use crate::typecheck::typedef::VarMap; use super::*; use nac3parser::ast::Constant; @@ -482,13 +483,13 @@ pub fn get_type_from_type_annotation_kinds( let mut tobj_fields = methods .iter() .map(|(name, ty, _)| { - let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + let subst_ty = unifier.subst(*ty, &subst.to_sorted()).unwrap_or(*ty); // methods are immutable (*name, (subst_ty, false)) }) .collect::>(); tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| { - let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); + let subst_ty = unifier.subst(*ty, &subst.to_sorted()).unwrap_or(*ty); (*name, (subst_ty, *mutability)) })); let need_subst = !subst.is_empty(); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a6b72bb..9065468 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -3,7 +3,7 @@ use std::convert::{From, TryInto}; use std::iter::once; use std::{cell::RefCell, sync::Arc}; -use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; +use super::typedef::{Call, FunSignature, FuncArg, RecordField, SortedVarMap, Type, TypeEnum, Unifier, VarMap}; use super::{magic_methods::*, typedef::CallId}; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, @@ -503,7 +503,7 @@ impl<'a> Fold<()> for Inferencer<'a> { assert_eq!(*id, *id_var); (*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).0) }) - .collect::(); + .collect::(); Some(self.unifier.subst(self.primitives.option, &var_map).unwrap()) } else { unreachable!("must be tobj") diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index c3c1c75..14f091b 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -26,13 +26,206 @@ pub struct CallId(pub(super) usize); pub type Mapping = HashMap; -/// A [`Mapping`] sorted by its key. -/// -/// This type is recommended for mappings that should be stored and/or iterated by its sorted key. -pub type SortedMapping = BTreeMap; +/// A sequenced collection storing type variable IDs and their [unifier types][Type], preserving the +/// order of type variables as it appears in the generic type or function declaration. This type +/// also guarantees the uniqueness of type variable IDs, i.e. only one type variable with the same +/// ID will be present in the map at all times. +#[derive(Clone, Default, PartialEq, Eq, Debug, Hash)] +pub struct VarMap { + var_ids: Vec, + var_tys: BTreeMap, +} -/// A [`BTreeMap`] storing the mapping between type variable ID and [unifier type][`Type`]. -pub type VarMap = SortedMapping; +impl VarMap { + /// Creates a new, empty [`VarMap`]. + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Returns the number of elements in this instance. + #[must_use] + pub fn len(&self) -> usize { + self.var_ids.len() + } + + /// Returns `true` if this instance contains no elements. + #[must_use] + pub fn is_empty(&self) -> bool { + self.var_ids.is_empty() + } + + /// Creates a consuming iterator visiting all entries of this instance as it appears in + /// declaration order. + /// + /// See [`IntoIterator::into_iter`]. + /// + /// Implementation Note: This function should be implemented as part of [`IntoIterator`], but + /// [`impl_trait_in_assoc_type`](https://github.com/rust-lang/rust/issues/63063) is not + /// stabilized. + fn into_iter(self) -> impl Iterator { + self.var_tys + .into_iter() + .sorted_by_key(|(id, _)| self.var_ids.iter().position(|var_id| id == var_id).unwrap()) + } + + /// Returns an iterator over the entries of this instance as it appears in declaration order. + pub fn iter(&self) -> impl Iterator { + self.var_tys + .iter() + .sorted_by_key(|(id, _)| self.var_ids.iter().position(|var_id| *id == var_id).unwrap()) + } + + /// Returns an iterator that allows modifying each value as it appears in declaration order. + pub fn iter_mut(&mut self) -> impl Iterator { + self.var_tys + .iter_mut() + .sorted_by_key(|(id, _)| self.var_ids.iter().position(|var_id| *id == var_id).unwrap()) + } + + /// Returns an iterator over the type variable IDs as it appears in declaration order. + pub fn ids(&self) -> impl Iterator { + self.var_ids.iter() + } + + /// Returns an iterator over the types of type variables as it appears in declaration order. + pub fn types(&self) -> impl Iterator { + self.iter().map(|(_, ty)| ty) + } + + /// Returns `true` if this instance contains the given type variable ID. + #[must_use] + pub fn contains_id(&self, var_id: &u32) -> bool { + self.get_key_value(var_id).is_some() + } + + /// Returns a reference to the type corresponding to the type variable ID. + #[must_use] + pub fn get(&self, var_id: &u32) -> Option<&Type> { + self.get_key_value(var_id).map(|(_, ty)| ty) + } + + /// Returns a mutable reference to the type corresponding to the type variable ID. + pub fn get_mut(&mut self, var_id: &u32) -> Option<&mut Type> { + self.var_tys.get_mut(var_id) + } + + /// Returns a reference to the pair of type variable ID and type corresponding to the ID. + #[must_use] + pub fn get_key_value(&self, var_id: &u32) -> Option<(&u32, &Type)> { + self.var_tys.get_key_value(var_id) + } + + /// Returns references to the type variable ID and associated type at the given declaration + /// `index`. + #[must_use] + pub fn index(&self, index: usize) -> Option<(&u32, &Type)> { + self.iter().nth(index) + } + + /// Returns mutable references to the type variable ID and associated type at the given + /// declaration `index`. + pub fn index_mut(&mut self, index: usize) -> Option<(&u32, &mut Type)> { + self.iter_mut().nth(index) + } + + /// Creates a [`SortedVarMap`] containing all entries in this instance. + #[must_use] + pub fn to_sorted(&self) -> SortedVarMap { + self.into() + } + + /// Inserts a type variable mapping into this instance. + /// + /// If the type variable ID already exists, the `ty` associated with the type variable will be + /// replaced. + pub fn insert(&mut self, var_id: u32, ty: Type) -> Option { + self.insert_impl(var_id, ty, None) + } + + /// Inserts a type variable mapping into this instance at the given `index`. + /// + /// If the type variable ID already exists, the `ty` associated with the type variable will be + /// replaced. + pub fn insert_at(&mut self, var_id: u32, ty: Type, index: usize) -> Option { + self.insert_impl(var_id, ty, Some(index)) + } +} + +impl VarMap { + fn insert_impl(&mut self, var_id: u32, ty: Type, index: Option) -> Option { + let old = self.var_tys.insert(var_id, ty); + if old.is_none() { + if let Some(index) = index { + self.var_ids.insert(index, var_id); + } else { + self.var_ids.push(var_id); + } + } + + assert_eq!(self.var_ids.len(), self.var_tys.len()); + + old + } +} + +impl Extend<(u32, Type)> for VarMap { + fn extend>(&mut self, iter: T) { + iter.into_iter().for_each(move |(k, v)| { + self.insert(k, v); + }); + } +} + +impl<'a> Extend<&'a (u32, Type)> for VarMap { + fn extend>(&mut self, iter: T) { + iter.into_iter().for_each(move |(k, v)| { + self.insert(*k, *v); + }); + } +} + +impl From<[(u32, Type); N]> for VarMap { + fn from(value: [(u32, Type); N]) -> Self { + if N == 0 { + return VarMap::new() + } + + value.into_iter().collect() + } +} + +impl FromIterator<(u32, Type)> for VarMap { + fn from_iter>(iter: T) -> Self { + let vars: Vec<(u32, Type)> = iter.into_iter().collect(); + let var_ids = vars.iter() + .map(|(id, _)| *id) + .unique() + .collect(); + let var_tys = vars.into_iter().collect(); + + Self { var_ids, var_tys } + } +} + +/// A [mapping][`Mapping`] between type variable IDs and [unifier type][`Type`]. +/// +/// As opposed to [`VarMap`], this type does not preserve the order of type variables as it appears +/// in the generic declaration. As such, this type should only be used when lookup operations should +/// be prioritized *and* the declaration order of type variables does not matter. +pub type SortedVarMap = Mapping; + +impl From for SortedVarMap { + fn from(value: VarMap) -> Self { + value.var_tys.into_iter().collect() + } +} + +impl From<&VarMap> for SortedVarMap { + fn from(value: &VarMap) -> Self { + value.var_tys.iter().map(|(id, ty)| (*id, *ty)).collect() + } +} #[derive(Clone)] pub struct Call { @@ -486,7 +679,7 @@ impl Unifier { TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TObj { params: vars, .. } => { - vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) + vars.types().all(|ty| self.is_concrete(*ty, allowed_typevars)) } } } @@ -920,8 +1113,8 @@ impl Unifier { // Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits // all K-V pairs "in arbitrary order" let (tv1, tv2) = ( - params1.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(), - params2.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(), + params1.iter().map(|(_, v)| v).collect_vec(), + params2.iter().map(|(_, v)| v).collect_vec(), ); for (x, y) in zip(tv1, tv2) { if self.unify_impl(*x, *y, false).is_err() { @@ -1097,11 +1290,9 @@ impl Unifier { if params.is_empty() { name } else { - let params = params + let mut params = params .iter() .map(|(_, v)| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); - // sort to preserve order - let mut params = params.sorted(); format!("{}[{}]", name, params.join(", ")) } } @@ -1151,7 +1342,7 @@ impl Unifier { fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { let mut instantiated = true; let mut vars = Vec::new(); - for (k, v) in &fun.vars { + for (k, v) in fun.vars.iter() { if let TypeEnum::TVar { id, name, loc, range, .. } = self.unification_table.probe_value(*v).as_ref() { @@ -1181,14 +1372,14 @@ impl Unifier { /// If this returns Some(T), T would be the substituted type. /// If this returns None, the result type would be the original type /// (no substitution has to be done). - pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option { + pub fn subst(&mut self, a: Type, mapping: &SortedVarMap) -> Option { self.subst_impl(a, mapping, &mut HashMap::new()) } fn subst_impl( &mut self, a: Type, - mapping: &VarMap, + mapping: &SortedVarMap, cache: &mut HashMap>, ) -> Option { let cached = cache.get_mut(&a); @@ -1231,7 +1422,7 @@ impl Unifier { // If the mapping does not contain any type variables in the // parameter list, we don't need to substitute the fields. // This is also used to prevent infinite substitution... - let need_subst = params.values().any(|v| { + let need_subst = params.types().any(|v| { let ty = self.unification_table.probe_value(*v); if let TypeEnum::TVar { id, .. } = ty.as_ref() { mapping.contains_key(id) @@ -1281,17 +1472,14 @@ impl Unifier { } } - fn subst_map( + fn subst_map( &mut self, - map: &SortedMapping, - mapping: &VarMap, + map: &VarMap, + mapping: &SortedVarMap, cache: &mut HashMap>, - ) -> Option> - where - K: Ord + Eq + Clone, - { + ) -> Option { let mut map2 = None; - for (k, v) in map { + for (k, v) in map.iter() { if let Some(v1) = self.subst_impl(*v, mapping, cache) { if map2.is_none() { map2 = Some(map.clone()); @@ -1305,7 +1493,7 @@ impl Unifier { fn subst_map2( &mut self, map: &Mapping, - mapping: &VarMap, + mapping: &SortedVarMap, cache: &mut HashMap>, ) -> Option> where diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 24eece9..2a050c0 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -45,10 +45,7 @@ impl Unifier { } } - fn map_eq(&mut self, map1: &SortedMapping, map2: &SortedMapping) -> bool - where - K: Ord + Eq + Clone, - { + fn map_eq(&mut self, map1: &VarMap, map2: &VarMap) -> bool { if map1.len() != map2.len() { return false; } @@ -186,7 +183,7 @@ impl TestEnvironment { s = &s[1..]; ty = self .unifier - .subst(ty, ¶ms.keys().cloned().zip(p.into_iter()).collect()) + .subst(ty, ¶ms.ids().cloned().zip(p.into_iter()).collect()) .unwrap_or(ty); } }