1
0
forked from M-Labs/nac3

type scheme instantiation

This commit is contained in:
pca006132 2021-06-30 17:18:56 +08:00
parent 2985b88351
commit 84c980fed3
2 changed files with 32 additions and 3 deletions
nac3core/src/typecheck

View File

@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::collections::HashSet;
use super::primitives::get_var;
use super::symbol_resolver::*;
@ -89,6 +90,13 @@ impl<'a> InferenceContext<'a> {
VariableId(id)
}
fn get_fresh_var_with_bound(&mut self, bound: Vec<Type>) -> VariableId {
self.local_variables.push(VarDef { name: None, bound });
let id = self.fresh_var_id;
self.fresh_var_id += 1;
VariableId(id)
}
pub fn assign_identifier(&mut self, identifier: &'a str) -> Type {
if let Some(t) = self.local_identifiers.get(identifier) {
t.clone()
@ -141,22 +149,42 @@ impl<'a> InferenceContext<'a> {
.base
.fields
.get(identifier)
.map_or_else(|| Err("no such field".to_owned()), |v| Ok(v))?;
.map_or_else(|| Err("no such field".to_owned()), Ok)?;
// function and tuple can have 0 type variables but with type parameters
// we require other types have the same number of type variables and type
// parameters in order to build a mapping
assert!(type_def.params.len() == 0 || type_def.params.len() == params.len());
assert!(type_def.params.is_empty() || type_def.params.len() == params.len());
let map = type_def
.params
.clone()
.into_iter()
.zip(params.clone().into_iter())
.collect();
Ok(field.subst(&map))
let field = field.subst(&map);
Ok(self.get_instance(field))
}
}
}
fn get_instance(&mut self, t: Type) -> Type {
let mut vars = HashSet::new();
t.get_vars(&mut vars);
let local_min = self.global.get_var_count();
let bounded = vars.into_iter().filter(|id| id.0 < local_min);
let map = bounded
.map(|v| {
(
v,
get_var(
self.get_fresh_var_with_bound(self.global.get_var_def(v).bound.clone()),
),
)
})
.collect();
t.subst(&map)
}
pub fn get_type_def(&self, id: TypeId) -> &TypeDef {
self.global.get_type_def(id)
}

View File

@ -1,3 +1,4 @@
#![allow(dead_code)]
mod context;
pub mod location;
mod magic_methods;