hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
2 changed files with 177 additions and 56 deletions
Showing only changes of commit d67407716c - Show all commits

View File

@ -8,14 +8,12 @@ mod test {
struct TestEnvironment { struct TestEnvironment {
pub unifier: Unifier, pub unifier: Unifier,
type_mapping: HashMap<String, Type>, type_mapping: HashMap<String, Type>,
var_max_id: u32,
} }
impl TestEnvironment { impl TestEnvironment {
fn new() -> TestEnvironment { fn new() -> TestEnvironment {
let unifier = Unifier::new(); let mut unifier = Unifier::new();
let mut type_mapping = HashMap::new(); let mut type_mapping = HashMap::new();
let mut var_max_id = 0;
type_mapping.insert( type_mapping.insert(
"int".into(), "int".into(),
@ -41,38 +39,30 @@ mod test {
params: HashMap::new(), params: HashMap::new(),
}), }),
); );
let v0 = unifier.add_ty(TypeEnum::TVar { id: 0 }); let (v0, id) = unifier.get_fresh_var();
var_max_id += 1;
type_mapping.insert( type_mapping.insert(
"Foo".into(), "Foo".into(),
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: 3, obj_id: 3,
fields: [("a".into(), v0)].iter().cloned().collect(), fields: [("a".into(), v0)].iter().cloned().collect(),
params: [(0u32, v0)].iter().cloned().collect(), params: [(id, v0)].iter().cloned().collect(),
}), }),
); );
TestEnvironment { TestEnvironment {
unifier, unifier,
type_mapping, type_mapping,
var_max_id,
} }
} }
fn get_fresh_var(&mut self) -> Type { fn parse(&mut self, typ: &str, mapping: &Mapping<String>) -> Type {
let id = self.var_max_id + 1;
self.var_max_id += 1;
self.unifier.add_ty(TypeEnum::TVar { id })
}
fn parse(&self, typ: &str, mapping: &Mapping<String>) -> Type {
let result = self.internal_parse(typ, mapping); let result = self.internal_parse(typ, mapping);
assert!(result.1.is_empty()); assert!(result.1.is_empty());
result.0 result.0
} }
fn internal_parse<'a, 'b>( fn internal_parse<'a, 'b>(
&'a self, &'a mut self,
typ: &'b str, typ: &'b str,
mapping: &Mapping<String>, mapping: &Mapping<String>,
) -> (Type, &'b str) { ) -> (Type, &'b str) {
@ -189,8 +179,8 @@ mod test {
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for i in 1..=variable_count { for i in 1..=variable_count {
let v = env.get_fresh_var(); let v = env.unifier.get_fresh_var();
mapping.insert(format!("v{}", i), v); mapping.insert(format!("v{}", i), v.0);
} }
// unification may have side effect when we do type resolution, so freeze the types // unification may have side effect when we do type resolution, so freeze the types
// before doing unification. // before doing unification.
@ -259,8 +249,8 @@ mod test {
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for i in 1..=variable_count { for i in 1..=variable_count {
let v = env.get_fresh_var(); let v = env.unifier.get_fresh_var();
mapping.insert(format!("v{}", i), v); mapping.insert(format!("v{}", i), v.0);
} }
// unification may have side effect when we do type resolution, so freeze the types // unification may have side effect when we do type resolution, so freeze the types
// before doing unification. // before doing unification.

View File

@ -48,7 +48,7 @@ impl Deref for TypeCell {
} }
pub type Mapping<K, V = Type> = HashMap<K, V>; pub type Mapping<K, V = Type> = HashMap<K, V>;
pub type VarMap = Mapping<u32>; type VarMap = Mapping<u32>;
#[derive(Clone)] #[derive(Clone)]
pub struct Call { pub struct Call {
@ -65,6 +65,13 @@ pub struct FuncArg {
is_optional: bool, is_optional: bool,
} }
#[derive(Clone)]
pub struct FunSignature {
args: Vec<FuncArg>,
ret: Type,
params: VarMap,
}
// We use a lot of `Rc`/`RefCell`s here as we want to simplify our code. // We use a lot of `Rc`/`RefCell`s here as we want to simplify our code.
// We may not really need so much `Rc`s, but we would have to do complicated // We may not really need so much `Rc`s, but we would have to do complicated
// stuffs otherwise. // stuffs otherwise.
@ -96,11 +103,7 @@ pub enum TypeEnum {
TCall { TCall {
calls: Vec<Rc<Call>>, calls: Vec<Rc<Call>>,
}, },
TFunc { TFunc(FunSignature),
args: Vec<FuncArg>,
ret: Type,
params: VarMap,
},
} }
// Order: // Order:
@ -199,40 +202,41 @@ pub struct ObjDef {
} }
pub struct Unifier { pub struct Unifier {
unification_table: RefCell<InPlaceUnificationTable<Type>>, unification_table: InPlaceUnificationTable<Type>,
obj_def_table: Vec<ObjDef>, obj_def_table: Vec<ObjDef>,
var_id: u32,
} }
impl Unifier { impl Unifier {
pub fn new() -> Unifier { pub fn new() -> Unifier {
Unifier { Unifier {
unification_table: RefCell::new(InPlaceUnificationTable::new()), unification_table: InPlaceUnificationTable::new(),
obj_def_table: Vec::new(), obj_def_table: Vec::new(),
var_id: 0,
} }
} }
/// Register a type to the unifier. /// Register a type to the unifier.
/// Returns a key in the unification_table. /// Returns a key in the unification_table.
pub fn add_ty(&self, a: TypeEnum) -> Type { pub fn add_ty(&mut self, a: TypeEnum) -> Type {
self.unification_table self.unification_table.new_key(TypeCell(Rc::new(a.into())))
.borrow_mut()
.new_key(TypeCell(Rc::new(a.into())))
} }
/// Get the TypeEnum of a type. /// Get the TypeEnum of a type.
pub fn get_ty(&self, a: Type) -> Rc<RefCell<TypeEnum>> { pub fn get_ty(&mut self, a: Type) -> Rc<RefCell<TypeEnum>> {
let mut table = self.unification_table.borrow_mut(); self.unification_table.probe_value(a).0
table.probe_value(a).0
} }
/// Unify two types, i.e. a = b. /// Unify two types, i.e. a = b.
pub fn unify(&self, mut a: Type, mut b: Type) -> Result<(), String> { pub fn unify(&mut self, mut a: Type, mut b: Type) -> Result<(), String> {
let (mut ty_a_cell, mut ty_b_cell) = { let (mut ty_a_cell, mut ty_b_cell) = {
let mut table = self.unification_table.borrow_mut(); if self.unification_table.unioned(a, b) {
if table.unioned(a, b) {
return Ok(()); return Ok(());
} }
(table.probe_value(a), table.probe_value(b)) (
self.unification_table.probe_value(a),
self.unification_table.probe_value(b),
)
}; };
let (ty_a, ty_b) = { let (ty_a, ty_b) = {
@ -353,7 +357,6 @@ impl Unifier {
TypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty } => {
// not sure if this is correct... // not sure if this is correct...
self.unify(a, *ty)?; self.unify(a, *ty)?;
self.set_a_to_b(a, b);
} }
_ => { _ => {
return self.incompatible_types(&*ty_a, &*ty_b); return self.incompatible_types(&*ty_a, &*ty_b);
@ -390,14 +393,105 @@ impl Unifier {
return self.incompatible_types(&*ty_a, &*ty_b); return self.incompatible_types(&*ty_a, &*ty_b);
} }
} }
_ => unimplemented!(), TypeEnum::TCall { calls: c1 } => match &*ty_b {
TypeEnum::TCall { .. } => {
drop(ty_b);
if let TypeEnum::TCall { calls: c2 } = &mut *ty_b_cell.as_ref().borrow_mut() {
c2.extend(c1.iter().cloned());
} else {
unreachable!()
}
self.set_a_to_b(a, b);
}
TypeEnum::TFunc(signature) => {
let required: Vec<String> = signature
.args
.iter()
.filter(|v| !v.is_optional)
.map(|v| v.name.clone())
.rev()
.collect();
for c in c1 {
let Call {
posargs,
kwargs,
ret,
fun,
} = c.as_ref();
let instantiated = self.instantiate_fun(b, signature);
let signature;
let r = self.get_ty(instantiated);
let r = r.as_ref().borrow();
if let TypeEnum::TFunc(s) = &*r {
signature = s;
} else {
unreachable!();
}
let mut required = required.clone();
let mut all_names: Vec<_> = signature
.args
.iter()
.map(|v| (v.name.clone(), v.ty))
.rev()
.collect();
for (i, t) in posargs.iter().enumerate() {
if signature.args.len() <= i {
return Err(format!("Too many arguments."));
}
if !required.is_empty() {
required.pop();
}
self.unify(all_names.pop().unwrap().1, *t)?;
}
for (k, t) in kwargs.iter() {
if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i);
}
if let Some(i) = all_names.iter().position(|v| &v.0 == k) {
self.unify(all_names.remove(i).1, *t)?;
} else {
return Err(format!("Unknown keyword argument {}", k));
}
}
self.unify(*ret, signature.ret)?;
*fun.borrow_mut() = Some(instantiated);
}
self.set_a_to_b(a, b);
}
_ => {
return self.incompatible_types(&*ty_a, &*ty_b);
}
},
TypeEnum::TFunc(sign1) => {
if let TypeEnum::TFunc(sign2) = &*ty_b {
if !sign1.params.is_empty() || !sign2.params.is_empty() {
return Err(format!("Polymorphic function pointer is prohibited."));
}
if sign1.args.len() != sign2.args.len() {
return Err(format!("Functions differ in number of parameters."));
}
for (x, y) in sign1.args.iter().zip(sign2.args.iter()) {
if x.name != y.name {
return Err(format!("Functions differ in parameter names."));
}
if x.is_optional != y.is_optional {
return Err(format!("Functions differ in optional parameters."));
}
self.unify(x.ty, y.ty)?;
}
self.unify(sign1.ret, sign2.ret)?;
self.set_a_to_b(a, b);
} else {
return self.incompatible_types(&*ty_a, &*ty_b);
}
}
} }
Ok(()) Ok(())
} }
fn set_a_to_b(&self, a: Type, b: Type) { fn set_a_to_b(&mut self, a: Type, b: Type) {
// unify a and b together, and set the value to b's value. // unify a and b together, and set the value to b's value.
let mut table = self.unification_table.borrow_mut(); let table = &mut self.unification_table;
let ty_b = table.probe_value(b); let ty_b = table.probe_value(b);
table.union(a, b); table.union(a, b);
table.union_value(a, ty_b); table.union_value(a, ty_b);
@ -411,11 +505,11 @@ impl Unifier {
)) ))
} }
fn occur_check(&self, a: Type, b: Type) -> Result<(), String> { fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> {
if self.unification_table.borrow_mut().unioned(a, b) { if self.unification_table.unioned(a, b) {
return Err("Recursive type is prohibited.".to_owned()); return Err("Recursive type is prohibited.".to_owned());
} }
let ty = self.unification_table.borrow_mut().probe_value(b); let ty = self.unification_table.probe_value(b);
let ty = ty.borrow(); let ty = ty.borrow();
match &*ty { match &*ty {
@ -454,7 +548,7 @@ impl Unifier {
self.occur_check(a, *t)?; self.occur_check(a, *t)?;
} }
} }
TypeEnum::TFunc { args, ret, params } => { TypeEnum::TFunc(FunSignature { args, ret, params }) => {
for t in args for t in args
.iter() .iter()
.map(|v| &v.ty) .map(|v| &v.ty)
@ -472,8 +566,8 @@ impl Unifier {
/// If this returns Some(T), T would be the substituted type. /// If this returns Some(T), T would be the substituted type.
/// If this returns None, the result type would be the original type /// If this returns None, the result type would be the original type
/// (no substitution has to be done). /// (no substitution has to be done).
pub fn subst(&self, a: Type, mapping: &VarMap) -> Option<Type> { pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
let ty_cell = self.unification_table.borrow_mut().probe_value(a); let ty_cell = self.unification_table.probe_value(a);
let ty = ty_cell.borrow(); let ty = ty_cell.borrow();
// this function would only be called when we instantiate functions. // this function would only be called when we instantiate functions.
// function type signature should ONLY contain concrete types and type // function type signature should ONLY contain concrete types and type
@ -512,10 +606,10 @@ impl Unifier {
// parameter list, we don't need to substitute the fields. // parameter list, we don't need to substitute the fields.
// This is also used to prevent infinite substitution... // This is also used to prevent infinite substitution...
let need_subst = params.values().any(|v| { let need_subst = params.values().any(|v| {
let ty_cell = self.unification_table.borrow_mut().probe_value(*v); let ty_cell = self.unification_table.probe_value(*v);
let ty = ty_cell.borrow(); let ty = ty_cell.borrow();
if let TypeEnum::TVar { id } = &*ty { if let TypeEnum::TVar { id } = &*ty {
mapping.contains_key(id) mapping.contains_key(&id)
} else { } else {
false false
} }
@ -537,7 +631,7 @@ impl Unifier {
None None
} }
} }
TypeEnum::TFunc { args, ret, params } => { TypeEnum::TFunc(FunSignature { args, ret, params }) => {
let new_params = self.subst_map(params, mapping); let new_params = self.subst_map(params, mapping);
let new_ret = self.subst(*ret, mapping); let new_ret = self.subst(*ret, mapping);
let mut new_args = None; let mut new_args = None;
@ -557,7 +651,7 @@ impl Unifier {
let params = new_params.unwrap_or_else(|| params.clone()); let params = new_params.unwrap_or_else(|| params.clone());
let ret = new_ret.unwrap_or_else(|| *ret); let ret = new_ret.unwrap_or_else(|| *ret);
let args = new_args.unwrap_or_else(|| args.clone()); let args = new_args.unwrap_or_else(|| args.clone());
Some(self.add_ty(TypeEnum::TFunc { params, ret, args })) Some(self.add_ty(TypeEnum::TFunc(FunSignature { params, ret, args })))
} else { } else {
None None
} }
@ -566,7 +660,7 @@ impl Unifier {
} }
} }
fn subst_map<K>(&self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>> fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
where where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{ {
@ -582,13 +676,43 @@ impl Unifier {
map2 map2
} }
/// Instantiate a function if it hasn't been instntiated.
/// Returns Some(T) where T is the instantiated type.
/// Returns None if the function is already instantiated.
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
let mut instantiated = false;
for (k, v) in fun.params.iter() {
if let TypeEnum::TVar { id } =
&*self.unification_table.probe_value(*v).as_ref().borrow()
{
if k != id {
instantiated = true;
break;
}
} else {
instantiated = true;
break;
}
}
if instantiated {
ty
} else {
let mapping = fun
.params
.iter()
.map(|(k, _)| (*k, self.get_fresh_var().0))
.collect();
self.subst(ty, &mapping).unwrap_or(ty)
}
}
/// Check whether two types are equal. /// Check whether two types are equal.
pub fn eq(&self, a: Type, b: Type) -> bool { pub fn eq(&mut self, a: Type, b: Type) -> bool {
if a == b { if a == b {
return true; return true;
} }
let (ty_a, ty_b) = { let (ty_a, ty_b) = {
let mut table = self.unification_table.borrow_mut(); let table = &mut self.unification_table;
if table.unioned(a, b) { if table.unioned(a, b) {
return true; return true;
} }
@ -629,7 +753,7 @@ impl Unifier {
} }
} }
fn map_eq<K>(&self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
where where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{ {
@ -643,4 +767,11 @@ impl Unifier {
} }
true true
} }
/// Get a fresh type variable.
pub fn get_fresh_var(&mut self) -> (Type, u32) {
let id = self.var_id + 1;
self.var_id += 1;
(self.add_ty(TypeEnum::TVar { id }), id)
}
} }