hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
2 changed files with 194 additions and 269 deletions
Showing only changes of commit e732f7e089 - Show all commits

View File

@ -206,7 +206,7 @@ mod test {
("v1", "Tuple[int]"), ("v1", "Tuple[int]"),
("v2", "List[int]"), ("v2", "List[int]"),
], ],
(("v1", "v2"), "Cannot unify TTuple with TList") (("v1", "v2"), "Cannot unify TList with TTuple")
; "type mismatch" ; "type mismatch"
)] )]
#[test_case(2, #[test_case(2,
@ -222,7 +222,7 @@ mod test {
("v1", "Tuple[int,int]"), ("v1", "Tuple[int,int]"),
("v2", "Tuple[int]"), ("v2", "Tuple[int]"),
], ],
(("v1", "v2"), "Cannot unify tuples with length 1 and 2") (("v1", "v2"), "Cannot unify tuples with length 2 and 1")
; "tuple length mismatch" ; "tuple length mismatch"
)] )]
#[test_case(3, #[test_case(3,

View File

@ -3,7 +3,6 @@ use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug; use std::fmt::Debug;
use std::iter::once; use std::iter::once;
use std::mem::swap;
use std::ops::Deref; use std::ops::Deref;
use std::rc::Rc; use std::rc::Rc;
@ -69,7 +68,7 @@ pub struct FuncArg {
pub struct FunSignature { pub struct FunSignature {
pub args: Vec<FuncArg>, pub args: Vec<FuncArg>,
pub ret: Type, pub ret: Type,
pub params: VarMap, pub vars: VarMap,
} }
// We use a lot of `Rc`/`RefCell`s here as we want to simplify our code. // We use a lot of `Rc`/`RefCell`s here as we want to simplify our code.
@ -117,62 +116,7 @@ pub enum TypeEnum {
// `--> TCall // `--> TCall
// `--> TFunc // `--> TFunc
// We encode the types as natural numbers, and the subtyping relation as divisibility:
// if a | b (a divides b), then b <: a.
// We assign a unique prime number to each type, except TVar, which gets 1 because every type is a subtype of it:
// TVar = 1
// |--> TSeq = 2
// | |--> TTuple = 3
// | `--> TList = 5
// |--> TRecord = 7
// | |--> TObj = 11
// | `--> TVirtual = 13
// `--> TCall = 17
// `--> TFunc = 21
//
// And then, based on the subtyping relation, multiply them together...
// TVar = 1
// |--> TSeq = 2 * TVar
// | |--> TTuple = 3 * TSeq * TVar
// | `--> TList = 5 * TSeq * TVar
// |--> TRecord = 7 * TVar
// | |--> TObj = 11 * TRecord * TVar
// | `--> TVirtual = 13 * TRecord * TVar
// `--> TCall = 17 * TVar
// `--> TFunc = 21 * TCall * TVar
impl TypeEnum { impl TypeEnum {
/// Returns the integer code that encodes this variant's position in the subtype hierarchy.
///
/// `TVar` is coded as 1. Every other variant owns a distinct prime factor, and its code is
/// that factor multiplied by the factors of all its ancestors. A variant's code is therefore
/// divisible by another's code exactly when it is a subtype of that other variant.
fn get_int(&self) -> i32 {
    const TVAR: i32 = 1;
    const TSEQ: i32 = 2;
    const TTUPLE: i32 = 3;
    const TLIST: i32 = 5;
    const TRECORD: i32 = 7;
    const TOBJ: i32 = 11;
    const TVIRTUAL: i32 = 13;
    const TCALL: i32 = 17;
    // Must be a prime not used above. The earlier value 21 = 3 * 7 shared factors with
    // TTUPLE and TRECORD, so TFunc's code (21 * 17 = 357) was divisible by TRecord's
    // code (7) and TFunc was wrongly treated as a subtype of TRecord.
    const TFUNC: i32 = 19;

    match self {
        TypeEnum::TVar { .. } => TVAR,
        TypeEnum::TSeq { .. } => TSEQ * TVAR,
        TypeEnum::TTuple { .. } => TTUPLE * TSEQ * TVAR,
        TypeEnum::TList { .. } => TLIST * TSEQ * TVAR,
        TypeEnum::TRecord { .. } => TRECORD * TVAR,
        TypeEnum::TObj { .. } => TOBJ * TRECORD * TVAR,
        TypeEnum::TVirtual { .. } => TVIRTUAL * TRECORD * TVAR,
        TypeEnum::TCall { .. } => TCALL * TVAR,
        TypeEnum::TFunc { .. } => TFUNC * TCALL * TVAR,
    }
}
/// Reports whether `self` is a subtype of `other` (for example, List <: Var).
///
/// The check is a divisibility test on the integer codes from `get_int`:
/// `self <: other` holds when `other`'s code divides `self`'s code.
pub fn type_le(&self, other: &TypeEnum) -> bool {
    self.get_int() % other.get_int() == 0
}
pub fn get_type_name(&self) -> &'static str { pub fn get_type_name(&self) -> &'static str {
// this function is for debugging only... // this function is for debugging only...
// a proper to_str implementation requires the context // a proper to_str implementation requires the context
@ -227,9 +171,14 @@ impl Unifier {
self.unification_table.probe_value(a).0 self.unification_table.probe_value(a).0
} }
/// Public entry point for unifying `a` with `b`.
///
/// Delegates to `unify_impl` with the `swapped` flag cleared, marking this as the first
/// attempt with the operands in the order the caller supplied them.
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
    let swapped = false;
    self.unify_impl(a, b, swapped)
}
/// Unify two types, i.e. a = b. /// Unify two types, i.e. a = b.
pub fn unify(&mut self, mut a: Type, mut b: Type) -> Result<(), String> { fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> {
let (mut ty_a_cell, mut ty_b_cell) = { use TypeEnum::*;
let (ty_a_cell, ty_b_cell) = {
if self.unification_table.unioned(a, b) { if self.unification_table.unioned(a, b) {
return Ok(()); return Ok(());
} }
@ -240,251 +189,215 @@ impl Unifier {
}; };
let (ty_a, ty_b) = { let (ty_a, ty_b) = {
// simplify our pattern matching...
if ty_a_cell.borrow().type_le(&ty_b_cell.borrow()) {
swap(&mut a, &mut b);
swap(&mut ty_a_cell, &mut ty_b_cell);
}
(ty_a_cell.borrow(), ty_b_cell.borrow()) (ty_a_cell.borrow(), ty_b_cell.borrow())
}; };
self.occur_check(a, b)?; self.occur_check(a, b)?;
match &*ty_a { match (&*ty_a, &*ty_b) {
TypeEnum::TVar { .. } => { (TypeEnum::TVar { .. }, _) => {
// TODO: type variables bound check...
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }
TypeEnum::TSeq { map: map1 } => { (TSeq { map: map1 }, TSeq { .. }) => {
match &*ty_b { drop(ty_b);
TypeEnum::TSeq { .. } => { if let TypeEnum::TSeq { map: map2 } = &mut *ty_b_cell.as_ref().borrow_mut() {
drop(ty_b); // unify them to map2
if let TypeEnum::TSeq { map: map2 } = &mut *ty_b_cell.as_ref().borrow_mut() for (key, value) in map1.iter() {
{ if let Some(ty) = map2.get(key) {
// unify them to map2 self.unify(*ty, *value)?;
for (key, value) in map1.iter() {
if let Some(ty) = map2.get(key) {
self.unify(*ty, *value)?;
} else {
map2.insert(*key, *value);
}
}
} else { } else {
unreachable!() map2.insert(*key, *value);
} }
self.set_a_to_b(a, b);
}
TypeEnum::TTuple { ty: types } => {
let len = types.len() as i32;
for (k, v) in map1.iter() {
// handle negative index
let ind = if *k < 0 { len + *k } else { *k };
if ind >= len || ind < 0 {
return Err(format!(
"Tuple index out of range. (Length: {}, Index: {})",
types.len(),
k
));
}
self.unify(*v, types[ind as usize])?;
}
self.set_a_to_b(a, b);
}
TypeEnum::TList { ty } => {
for v in map1.values() {
self.unify(*v, *ty)?;
}
self.set_a_to_b(a, b);
}
_ => {
return self.incompatible_types(&*ty_a, &*ty_b);
} }
} else {
unreachable!()
} }
self.set_a_to_b(a, b);
} }
TypeEnum::TTuple { ty: ty1 } => { (TSeq { map: map1 }, TTuple { ty: types }) => {
if let TypeEnum::TTuple { ty: ty2 } = &*ty_b { let len = types.len() as i32;
if ty1.len() != ty2.len() { for (k, v) in map1.iter() {
// handle negative index
let ind = if *k < 0 { len + *k } else { *k };
if ind >= len || ind < 0 {
return Err(format!( return Err(format!(
"Cannot unify tuples with length {} and {}", "Tuple index out of range. (Length: {}, Index: {})",
ty1.len(), types.len(),
ty2.len() k
)); ));
} }
for (x, y) in ty1.iter().zip(ty2.iter()) { self.unify(*v, types[ind as usize])?;
self.unify(*x, *y)?;
}
self.set_a_to_b(a, b);
} else {
return self.incompatible_types(&*ty_a, &*ty_b);
} }
self.set_a_to_b(a, b);
} }
TypeEnum::TList { ty: ty1 } => { (TSeq { map: map1 }, TList { ty }) => {
if let TypeEnum::TList { ty: ty2 } = *ty_b { for v in map1.values() {
self.unify(*ty1, ty2)?; self.unify(*v, *ty)?;
self.set_a_to_b(a, b);
} else {
return self.incompatible_types(&*ty_a, &*ty_b);
} }
self.set_a_to_b(a, b);
} }
TypeEnum::TRecord { fields: fields1 } => { (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
match &*ty_b { if ty1.len() != ty2.len() {
TypeEnum::TRecord { .. } => { return Err(format!(
drop(ty_b); "Cannot unify tuples with length {} and {}",
if let TypeEnum::TRecord { fields: fields2 } = ty1.len(),
&mut *ty_b_cell.as_ref().borrow_mut() ty2.len()
{ ));
for (key, value) in fields1.iter() { }
if let Some(ty) = fields2.get(key) { for (x, y) in ty1.iter().zip(ty2.iter()) {
self.unify(*ty, *value)?; self.unify(*x, *y)?;
} else { }
fields2.insert(key.clone(), *value); self.set_a_to_b(a, b);
} }
} (TList { ty: ty1 }, TList { ty: ty2 }) => {
self.unify(*ty1, *ty2)?;
self.set_a_to_b(a, b);
}
(TRecord { fields: fields1 }, TRecord { .. }) => {
drop(ty_b);
if let TypeEnum::TRecord { fields: fields2 } = &mut *ty_b_cell.as_ref().borrow_mut()
{
for (key, value) in fields1.iter() {
if let Some(ty) = fields2.get(key) {
self.unify(*ty, *value)?;
} else { } else {
unreachable!() fields2.insert(key.clone(), *value);
} }
self.set_a_to_b(a, b);
} }
TypeEnum::TObj { } else {
fields: fields2, .. unreachable!()
} => { }
for (key, value) in fields1.iter() { self.set_a_to_b(a, b);
if let Some(ty) = fields2.get(key) { }
self.unify(*ty, *value)?; (
} else { TRecord { fields: fields1 },
return Err(format!("No such attribute {}", key)); TObj {
} fields: fields2, ..
} },
self.set_a_to_b(a, b); ) => {
} for (key, value) in fields1.iter() {
TypeEnum::TVirtual { ty } => { if let Some(ty) = fields2.get(key) {
// not sure if this is correct... self.unify(*ty, *value)?;
self.unify(a, *ty)?; } else {
} return Err(format!("No such attribute {}", key));
_ => {
return self.incompatible_types(&*ty_a, &*ty_b);
} }
} }
self.set_a_to_b(a, b);
} }
TypeEnum::TObj { (TRecord { .. }, TVirtual { ty }) => {
obj_id: id1, self.unify(a, *ty)?;
params: params1, }
.. (
} => { TObj {
if let TypeEnum::TObj { obj_id: id1,
params: params1,
..
},
TObj {
obj_id: id2, obj_id: id2,
params: params2, params: params2,
.. ..
} = &*ty_b },
{ ) => {
if id1 != id2 { if id1 != id2 {
return Err(format!("Cannot unify objects with ID {} and {}", id1, id2)); return Err(format!("Cannot unify objects with ID {} and {}", id1, id2));
}
for (x, y) in params1.values().zip(params2.values()) {
self.unify(*x, *y)?;
}
self.set_a_to_b(a, b);
} else {
return self.incompatible_types(&*ty_a, &*ty_b);
} }
} for (x, y) in params1.values().zip(params2.values()) {
TypeEnum::TVirtual { ty: ty1 } => { self.unify(*x, *y)?;
if let TypeEnum::TVirtual { ty: ty2 } = &*ty_b {
self.unify(*ty1, *ty2)?;
self.set_a_to_b(a, b);
} else {
return self.incompatible_types(&*ty_a, &*ty_b);
} }
self.set_a_to_b(a, b);
} }
TypeEnum::TCall { calls: c1 } => match &*ty_b { (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
TypeEnum::TCall { .. } => { self.unify(*ty1, *ty2)?;
drop(ty_b); self.set_a_to_b(a, b);
if let TypeEnum::TCall { calls: c2 } = &mut *ty_b_cell.as_ref().borrow_mut() { }
c2.extend(c1.iter().cloned()); (TCall { calls: c1 }, TCall { .. }) => {
drop(ty_b);
if let TypeEnum::TCall { calls: c2 } = &mut *ty_b_cell.as_ref().borrow_mut() {
c2.extend(c1.iter().cloned());
} else {
unreachable!()
}
self.set_a_to_b(a, b);
}
(TCall { calls }, TFunc(signature)) => {
let required: Vec<String> = signature
.args
.iter()
.filter(|v| !v.is_optional)
.map(|v| v.name.clone())
.rev()
.collect();
for c in calls {
let Call {
posargs,
kwargs,
ret,
fun,
} = c.as_ref();
let instantiated = self.instantiate_fun(b, signature);
let signature;
let r = self.get_ty(instantiated);
let r = r.as_ref().borrow();
if let TypeEnum::TFunc(s) = &*r {
signature = s;
} else { } else {
unreachable!() unreachable!();
} }
self.set_a_to_b(a, b); let mut required = required.clone();
} let mut all_names: Vec<_> = signature
TypeEnum::TFunc(signature) => {
let required: Vec<String> = signature
.args .args
.iter() .iter()
.filter(|v| !v.is_optional) .map(|v| (v.name.clone(), v.ty))
.map(|v| v.name.clone())
.rev() .rev()
.collect(); .collect();
for c in c1 { for (i, t) in posargs.iter().enumerate() {
let Call { if signature.args.len() <= i {
posargs, return Err("Too many arguments.".to_string());
kwargs, }
ret, if !required.is_empty() {
fun, required.pop();
} = c.as_ref(); }
let instantiated = self.instantiate_fun(b, signature); self.unify(all_names.pop().unwrap().1, *t)?;
let signature; }
let r = self.get_ty(instantiated); for (k, t) in kwargs.iter() {
let r = r.as_ref().borrow(); if let Some(i) = required.iter().position(|v| v == k) {
if let TypeEnum::TFunc(s) = &*r { required.remove(i);
signature = s; }
if let Some(i) = all_names.iter().position(|v| &v.0 == k) {
self.unify(all_names.remove(i).1, *t)?;
} else { } else {
unreachable!(); return Err(format!("Unknown keyword argument {}", k));
} }
let mut required = required.clone();
let mut all_names: Vec<_> = signature
.args
.iter()
.map(|v| (v.name.clone(), v.ty))
.rev()
.collect();
for (i, t) in posargs.iter().enumerate() {
if signature.args.len() <= i {
return Err("Too many arguments.".to_string());
}
if !required.is_empty() {
required.pop();
}
self.unify(all_names.pop().unwrap().1, *t)?;
}
for (k, t) in kwargs.iter() {
if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i);
}
if let Some(i) = all_names.iter().position(|v| &v.0 == k) {
self.unify(all_names.remove(i).1, *t)?;
} else {
return Err(format!("Unknown keyword argument {}", k));
}
}
self.unify(*ret, signature.ret)?;
*fun.borrow_mut() = Some(instantiated);
} }
self.set_a_to_b(a, b); self.unify(*ret, signature.ret)?;
*fun.borrow_mut() = Some(instantiated);
} }
_ => { self.set_a_to_b(a, b);
}
(TFunc(sign1), TFunc(sign2)) => {
if !sign1.vars.is_empty() || !sign2.vars.is_empty() {
return Err("Polymorphic function pointer is prohibited.".to_string());
}
if sign1.args.len() != sign2.args.len() {
return Err("Functions differ in number of parameters.".to_string());
}
for (x, y) in sign1.args.iter().zip(sign2.args.iter()) {
if x.name != y.name {
return Err("Functions differ in parameter names.".to_string());
}
if x.is_optional != y.is_optional {
return Err("Functions differ in optional parameters.".to_string());
}
self.unify(x.ty, y.ty)?;
}
self.unify(sign1.ret, sign2.ret)?;
self.set_a_to_b(a, b);
}
_ => {
if swapped {
return self.incompatible_types(&*ty_a, &*ty_b); return self.incompatible_types(&*ty_a, &*ty_b);
}
},
TypeEnum::TFunc(sign1) => {
if let TypeEnum::TFunc(sign2) = &*ty_b {
if !sign1.params.is_empty() || !sign2.params.is_empty() {
return Err("Polymorphic function pointer is prohibited.".to_string());
}
if sign1.args.len() != sign2.args.len() {
return Err("Functions differ in number of parameters.".to_string());
}
for (x, y) in sign1.args.iter().zip(sign2.args.iter()) {
if x.name != y.name {
return Err("Functions differ in parameter names.".to_string());
}
if x.is_optional != y.is_optional {
return Err("Functions differ in optional parameters.".to_string());
}
self.unify(x.ty, y.ty)?;
}
self.unify(sign1.ret, sign2.ret)?;
self.set_a_to_b(a, b);
} else { } else {
return self.incompatible_types(&*ty_a, &*ty_b); self.unify_impl(b, a, true)?;
} }
} }
} }
@ -555,7 +468,11 @@ impl Unifier {
self.occur_check(a, *t)?; self.occur_check(a, *t)?;
} }
} }
TypeEnum::TFunc(FunSignature { args, ret, params }) => { TypeEnum::TFunc(FunSignature {
args,
ret,
vars: params,
}) => {
for t in args for t in args
.iter() .iter()
.map(|v| &v.ty) .map(|v| &v.ty)
@ -638,7 +555,11 @@ impl Unifier {
None None
} }
} }
TypeEnum::TFunc(FunSignature { args, ret, params }) => { TypeEnum::TFunc(FunSignature {
args,
ret,
vars: params,
}) => {
let new_params = self.subst_map(params, mapping); let new_params = self.subst_map(params, mapping);
let new_ret = self.subst(*ret, mapping); let new_ret = self.subst(*ret, mapping);
let mut new_args = None; let mut new_args = None;
@ -658,7 +579,11 @@ impl Unifier {
let params = new_params.unwrap_or_else(|| params.clone()); let params = new_params.unwrap_or_else(|| params.clone());
let ret = new_ret.unwrap_or_else(|| *ret); let ret = new_ret.unwrap_or_else(|| *ret);
let args = new_args.unwrap_or_else(|| args.clone()); let args = new_args.unwrap_or_else(|| args.clone());
Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, params }))) Some(self.add_ty(TypeEnum::TFunc(FunSignature {
args,
ret,
vars: params,
})))
} else { } else {
None None
} }
@ -688,7 +613,7 @@ impl Unifier {
/// Returns None if the function is already instantiated. /// Returns None if the function is already instantiated.
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
let mut instantiated = false; let mut instantiated = false;
for (k, v) in fun.params.iter() { for (k, v) in fun.vars.iter() {
if let TypeEnum::TVar { id } = if let TypeEnum::TVar { id } =
&*self.unification_table.probe_value(*v).as_ref().borrow() &*self.unification_table.probe_value(*v).as_ref().borrow()
{ {
@ -705,7 +630,7 @@ impl Unifier {
ty ty
} else { } else {
let mapping = fun let mapping = fun
.params .vars
.iter() .iter()
.map(|(k, _)| (*k, self.get_fresh_var().0)) .map(|(k, _)| (*k, self.get_fresh_var().0))
.collect(); .collect();