hm-inference #6
|
@ -206,7 +206,7 @@ mod test {
|
||||||
("v1", "Tuple[int]"),
|
("v1", "Tuple[int]"),
|
||||||
("v2", "List[int]"),
|
("v2", "List[int]"),
|
||||||
],
|
],
|
||||||
(("v1", "v2"), "Cannot unify TTuple with TList")
|
(("v1", "v2"), "Cannot unify TList with TTuple")
|
||||||
; "type mismatch"
|
; "type mismatch"
|
||||||
)]
|
)]
|
||||||
#[test_case(2,
|
#[test_case(2,
|
||||||
|
@ -222,7 +222,7 @@ mod test {
|
||||||
("v1", "Tuple[int,int]"),
|
("v1", "Tuple[int,int]"),
|
||||||
("v2", "Tuple[int]"),
|
("v2", "Tuple[int]"),
|
||||||
],
|
],
|
||||||
(("v1", "v2"), "Cannot unify tuples with length 1 and 2")
|
(("v1", "v2"), "Cannot unify tuples with length 2 and 1")
|
||||||
; "tuple length mismatch"
|
; "tuple length mismatch"
|
||||||
)]
|
)]
|
||||||
#[test_case(3,
|
#[test_case(3,
|
||||||
|
|
|
@ -3,7 +3,6 @@ use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
use std::mem::swap;
|
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
@ -69,7 +68,7 @@ pub struct FuncArg {
|
||||||
pub struct FunSignature {
|
pub struct FunSignature {
|
||||||
pub args: Vec<FuncArg>,
|
pub args: Vec<FuncArg>,
|
||||||
pub ret: Type,
|
pub ret: Type,
|
||||||
pub params: VarMap,
|
pub vars: VarMap,
|
||||||
}
|
}
|
||||||
|
|
||||||
// We use a lot of `Rc`/`RefCell`s here as we want to simplify our code.
|
// We use a lot of `Rc`/`RefCell`s here as we want to simplify our code.
|
||||||
|
@ -117,62 +116,7 @@ pub enum TypeEnum {
|
||||||
// `--> TCall
|
// `--> TCall
|
||||||
// `--> TFunc
|
// `--> TFunc
|
||||||
|
|
||||||
// We encode the types as natural numbers, and subtyping relation as divisibility.
|
|
||||||
// If a | b, b <: a.
|
|
||||||
// We assign unique prime numbers (1 to TVar, everything is a subtype of it) to each type:
|
|
||||||
// TVar = 1
|
|
||||||
// |--> TSeq = 2
|
|
||||||
// | |--> TTuple = 3
|
|
||||||
// | `--> TList = 5
|
|
||||||
// |--> TRecord = 7
|
|
||||||
// | |--> TObj = 11
|
|
||||||
// | `--> TVirtual = 13
|
|
||||||
// `--> TCall = 17
|
|
||||||
// `--> TFunc = 21
|
|
||||||
//
|
|
||||||
// And then, based on the subtyping relation, multiply them together...
|
|
||||||
// TVar = 1
|
|
||||||
// |--> TSeq = 2 * TVar
|
|
||||||
// | |--> TTuple = 3 * TSeq * TVar
|
|
||||||
// | `--> TList = 5 * TSeq * TVar
|
|
||||||
// |--> TRecord = 7 * TVar
|
|
||||||
// | |--> TObj = 11 * TRecord * TVar
|
|
||||||
// | `--> TVirtual = 13 * TRecord * TVar
|
|
||||||
// `--> TCall = 17 * TVar
|
|
||||||
// `--> TFunc = 21 * TCall * TVar
|
|
||||||
|
|
||||||
impl TypeEnum {
|
impl TypeEnum {
|
||||||
fn get_int(&self) -> i32 {
|
|
||||||
const TVAR: i32 = 1;
|
|
||||||
const TSEQ: i32 = 2;
|
|
||||||
const TTUPLE: i32 = 3;
|
|
||||||
const TLIST: i32 = 5;
|
|
||||||
const TRECORD: i32 = 7;
|
|
||||||
const TOBJ: i32 = 11;
|
|
||||||
const TVIRTUAL: i32 = 13;
|
|
||||||
const TCALL: i32 = 17;
|
|
||||||
const TFUNC: i32 = 21;
|
|
||||||
|
|
||||||
match self {
|
|
||||||
TypeEnum::TVar { .. } => TVAR,
|
|
||||||
TypeEnum::TSeq { .. } => TSEQ * TVAR,
|
|
||||||
TypeEnum::TTuple { .. } => TTUPLE * TSEQ * TVAR,
|
|
||||||
TypeEnum::TList { .. } => TLIST * TSEQ * TVAR,
|
|
||||||
TypeEnum::TRecord { .. } => TRECORD * TVAR,
|
|
||||||
TypeEnum::TObj { .. } => TOBJ * TRECORD * TVAR,
|
|
||||||
TypeEnum::TVirtual { .. } => TVIRTUAL * TRECORD * TVAR,
|
|
||||||
TypeEnum::TCall { .. } => TCALL * TVAR,
|
|
||||||
TypeEnum::TFunc { .. } => TFUNC * TCALL * TVAR,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// e.g. List <: Var
|
|
||||||
pub fn type_le(&self, other: &TypeEnum) -> bool {
|
|
||||||
let a = self.get_int();
|
|
||||||
let b = other.get_int();
|
|
||||||
(a % b) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_type_name(&self) -> &'static str {
|
pub fn get_type_name(&self) -> &'static str {
|
||||||
// this function is for debugging only...
|
// this function is for debugging only...
|
||||||
// a proper to_str implementation requires the context
|
// a proper to_str implementation requires the context
|
||||||
|
@ -227,9 +171,14 @@ impl Unifier {
|
||||||
self.unification_table.probe_value(a).0
|
self.unification_table.probe_value(a).0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||||
|
self.unify_impl(a, b, false)
|
||||||
|
}
|
||||||
|
|
||||||
/// Unify two types, i.e. a = b.
|
/// Unify two types, i.e. a = b.
|
||||||
pub fn unify(&mut self, mut a: Type, mut b: Type) -> Result<(), String> {
|
fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> {
|
||||||
let (mut ty_a_cell, mut ty_b_cell) = {
|
use TypeEnum::*;
|
||||||
|
let (ty_a_cell, ty_b_cell) = {
|
||||||
if self.unification_table.unioned(a, b) {
|
if self.unification_table.unioned(a, b) {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
@ -240,26 +189,17 @@ impl Unifier {
|
||||||
};
|
};
|
||||||
|
|
||||||
let (ty_a, ty_b) = {
|
let (ty_a, ty_b) = {
|
||||||
// simplify our pattern matching...
|
|
||||||
if ty_a_cell.borrow().type_le(&ty_b_cell.borrow()) {
|
|
||||||
swap(&mut a, &mut b);
|
|
||||||
swap(&mut ty_a_cell, &mut ty_b_cell);
|
|
||||||
}
|
|
||||||
(ty_a_cell.borrow(), ty_b_cell.borrow())
|
(ty_a_cell.borrow(), ty_b_cell.borrow())
|
||||||
};
|
};
|
||||||
|
|
||||||
self.occur_check(a, b)?;
|
self.occur_check(a, b)?;
|
||||||
match &*ty_a {
|
match (&*ty_a, &*ty_b) {
|
||||||
TypeEnum::TVar { .. } => {
|
(TypeEnum::TVar { .. }, _) => {
|
||||||
// TODO: type variables bound check...
|
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
TypeEnum::TSeq { map: map1 } => {
|
(TSeq { map: map1 }, TSeq { .. }) => {
|
||||||
match &*ty_b {
|
|
||||||
TypeEnum::TSeq { .. } => {
|
|
||||||
drop(ty_b);
|
drop(ty_b);
|
||||||
if let TypeEnum::TSeq { map: map2 } = &mut *ty_b_cell.as_ref().borrow_mut()
|
if let TypeEnum::TSeq { map: map2 } = &mut *ty_b_cell.as_ref().borrow_mut() {
|
||||||
{
|
|
||||||
// unify them to map2
|
// unify them to map2
|
||||||
for (key, value) in map1.iter() {
|
for (key, value) in map1.iter() {
|
||||||
if let Some(ty) = map2.get(key) {
|
if let Some(ty) = map2.get(key) {
|
||||||
|
@ -273,7 +213,7 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
TypeEnum::TTuple { ty: types } => {
|
(TSeq { map: map1 }, TTuple { ty: types }) => {
|
||||||
let len = types.len() as i32;
|
let len = types.len() as i32;
|
||||||
for (k, v) in map1.iter() {
|
for (k, v) in map1.iter() {
|
||||||
// handle negative index
|
// handle negative index
|
||||||
|
@ -289,19 +229,13 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
TypeEnum::TList { ty } => {
|
(TSeq { map: map1 }, TList { ty }) => {
|
||||||
for v in map1.values() {
|
for v in map1.values() {
|
||||||
self.unify(*v, *ty)?;
|
self.unify(*v, *ty)?;
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
_ => {
|
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
||||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TypeEnum::TTuple { ty: ty1 } => {
|
|
||||||
if let TypeEnum::TTuple { ty: ty2 } = &*ty_b {
|
|
||||||
if ty1.len() != ty2.len() {
|
if ty1.len() != ty2.len() {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"Cannot unify tuples with length {} and {}",
|
"Cannot unify tuples with length {} and {}",
|
||||||
|
@ -313,24 +247,14 @@ impl Unifier {
|
||||||
self.unify(*x, *y)?;
|
self.unify(*x, *y)?;
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
} else {
|
|
||||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
|
||||||
}
|
}
|
||||||
}
|
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
||||||
TypeEnum::TList { ty: ty1 } => {
|
self.unify(*ty1, *ty2)?;
|
||||||
if let TypeEnum::TList { ty: ty2 } = *ty_b {
|
|
||||||
self.unify(*ty1, ty2)?;
|
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
} else {
|
|
||||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
|
||||||
}
|
}
|
||||||
}
|
(TRecord { fields: fields1 }, TRecord { .. }) => {
|
||||||
TypeEnum::TRecord { fields: fields1 } => {
|
|
||||||
match &*ty_b {
|
|
||||||
TypeEnum::TRecord { .. } => {
|
|
||||||
drop(ty_b);
|
drop(ty_b);
|
||||||
if let TypeEnum::TRecord { fields: fields2 } =
|
if let TypeEnum::TRecord { fields: fields2 } = &mut *ty_b_cell.as_ref().borrow_mut()
|
||||||
&mut *ty_b_cell.as_ref().borrow_mut()
|
|
||||||
{
|
{
|
||||||
for (key, value) in fields1.iter() {
|
for (key, value) in fields1.iter() {
|
||||||
if let Some(ty) = fields2.get(key) {
|
if let Some(ty) = fields2.get(key) {
|
||||||
|
@ -344,9 +268,12 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
TypeEnum::TObj {
|
(
|
||||||
|
TRecord { fields: fields1 },
|
||||||
|
TObj {
|
||||||
fields: fields2, ..
|
fields: fields2, ..
|
||||||
} => {
|
},
|
||||||
|
) => {
|
||||||
for (key, value) in fields1.iter() {
|
for (key, value) in fields1.iter() {
|
||||||
if let Some(ty) = fields2.get(key) {
|
if let Some(ty) = fields2.get(key) {
|
||||||
self.unify(*ty, *value)?;
|
self.unify(*ty, *value)?;
|
||||||
|
@ -356,26 +283,21 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
TypeEnum::TVirtual { ty } => {
|
(TRecord { .. }, TVirtual { ty }) => {
|
||||||
// not sure if this is correct...
|
|
||||||
self.unify(a, *ty)?;
|
self.unify(a, *ty)?;
|
||||||
}
|
}
|
||||||
_ => {
|
(
|
||||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
TObj {
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TypeEnum::TObj {
|
|
||||||
obj_id: id1,
|
obj_id: id1,
|
||||||
params: params1,
|
params: params1,
|
||||||
..
|
..
|
||||||
} => {
|
},
|
||||||
if let TypeEnum::TObj {
|
TObj {
|
||||||
obj_id: id2,
|
obj_id: id2,
|
||||||
params: params2,
|
params: params2,
|
||||||
..
|
..
|
||||||
} = &*ty_b
|
},
|
||||||
{
|
) => {
|
||||||
if id1 != id2 {
|
if id1 != id2 {
|
||||||
return Err(format!("Cannot unify objects with ID {} and {}", id1, id2));
|
return Err(format!("Cannot unify objects with ID {} and {}", id1, id2));
|
||||||
}
|
}
|
||||||
|
@ -383,20 +305,12 @@ impl Unifier {
|
||||||
self.unify(*x, *y)?;
|
self.unify(*x, *y)?;
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
} else {
|
|
||||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
|
||||||
}
|
}
|
||||||
}
|
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
||||||
TypeEnum::TVirtual { ty: ty1 } => {
|
|
||||||
if let TypeEnum::TVirtual { ty: ty2 } = &*ty_b {
|
|
||||||
self.unify(*ty1, *ty2)?;
|
self.unify(*ty1, *ty2)?;
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
} else {
|
|
||||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
|
||||||
}
|
}
|
||||||
}
|
(TCall { calls: c1 }, TCall { .. }) => {
|
||||||
TypeEnum::TCall { calls: c1 } => match &*ty_b {
|
|
||||||
TypeEnum::TCall { .. } => {
|
|
||||||
drop(ty_b);
|
drop(ty_b);
|
||||||
if let TypeEnum::TCall { calls: c2 } = &mut *ty_b_cell.as_ref().borrow_mut() {
|
if let TypeEnum::TCall { calls: c2 } = &mut *ty_b_cell.as_ref().borrow_mut() {
|
||||||
c2.extend(c1.iter().cloned());
|
c2.extend(c1.iter().cloned());
|
||||||
|
@ -405,7 +319,7 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
TypeEnum::TFunc(signature) => {
|
(TCall { calls }, TFunc(signature)) => {
|
||||||
let required: Vec<String> = signature
|
let required: Vec<String> = signature
|
||||||
.args
|
.args
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -413,7 +327,7 @@ impl Unifier {
|
||||||
.map(|v| v.name.clone())
|
.map(|v| v.name.clone())
|
||||||
.rev()
|
.rev()
|
||||||
.collect();
|
.collect();
|
||||||
for c in c1 {
|
for c in calls {
|
||||||
let Call {
|
let Call {
|
||||||
posargs,
|
posargs,
|
||||||
kwargs,
|
kwargs,
|
||||||
|
@ -460,13 +374,8 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
_ => {
|
(TFunc(sign1), TFunc(sign2)) => {
|
||||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
if !sign1.vars.is_empty() || !sign2.vars.is_empty() {
|
||||||
}
|
|
||||||
},
|
|
||||||
TypeEnum::TFunc(sign1) => {
|
|
||||||
if let TypeEnum::TFunc(sign2) = &*ty_b {
|
|
||||||
if !sign1.params.is_empty() || !sign2.params.is_empty() {
|
|
||||||
return Err("Polymorphic function pointer is prohibited.".to_string());
|
return Err("Polymorphic function pointer is prohibited.".to_string());
|
||||||
}
|
}
|
||||||
if sign1.args.len() != sign2.args.len() {
|
if sign1.args.len() != sign2.args.len() {
|
||||||
|
@ -483,8 +392,12 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
self.unify(sign1.ret, sign2.ret)?;
|
self.unify(sign1.ret, sign2.ret)?;
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
} else {
|
}
|
||||||
|
_ => {
|
||||||
|
if swapped {
|
||||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
return self.incompatible_types(&*ty_a, &*ty_b);
|
||||||
|
} else {
|
||||||
|
self.unify_impl(b, a, true)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -555,7 +468,11 @@ impl Unifier {
|
||||||
self.occur_check(a, *t)?;
|
self.occur_check(a, *t)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TFunc(FunSignature { args, ret, params }) => {
|
TypeEnum::TFunc(FunSignature {
|
||||||
|
args,
|
||||||
|
ret,
|
||||||
|
vars: params,
|
||||||
|
}) => {
|
||||||
for t in args
|
for t in args
|
||||||
.iter()
|
.iter()
|
||||||
.map(|v| &v.ty)
|
.map(|v| &v.ty)
|
||||||
|
@ -638,7 +555,11 @@ impl Unifier {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TFunc(FunSignature { args, ret, params }) => {
|
TypeEnum::TFunc(FunSignature {
|
||||||
|
args,
|
||||||
|
ret,
|
||||||
|
vars: params,
|
||||||
|
}) => {
|
||||||
let new_params = self.subst_map(params, mapping);
|
let new_params = self.subst_map(params, mapping);
|
||||||
let new_ret = self.subst(*ret, mapping);
|
let new_ret = self.subst(*ret, mapping);
|
||||||
let mut new_args = None;
|
let mut new_args = None;
|
||||||
|
@ -658,7 +579,11 @@ impl Unifier {
|
||||||
let params = new_params.unwrap_or_else(|| params.clone());
|
let params = new_params.unwrap_or_else(|| params.clone());
|
||||||
let ret = new_ret.unwrap_or_else(|| *ret);
|
let ret = new_ret.unwrap_or_else(|| *ret);
|
||||||
let args = new_args.unwrap_or_else(|| args.clone());
|
let args = new_args.unwrap_or_else(|| args.clone());
|
||||||
Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, params })))
|
Some(self.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args,
|
||||||
|
ret,
|
||||||
|
vars: params,
|
||||||
|
})))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
@ -688,7 +613,7 @@ impl Unifier {
|
||||||
/// Returns None if the function is already instantiated.
|
/// Returns None if the function is already instantiated.
|
||||||
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
|
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
|
||||||
let mut instantiated = false;
|
let mut instantiated = false;
|
||||||
for (k, v) in fun.params.iter() {
|
for (k, v) in fun.vars.iter() {
|
||||||
if let TypeEnum::TVar { id } =
|
if let TypeEnum::TVar { id } =
|
||||||
&*self.unification_table.probe_value(*v).as_ref().borrow()
|
&*self.unification_table.probe_value(*v).as_ref().borrow()
|
||||||
{
|
{
|
||||||
|
@ -705,7 +630,7 @@ impl Unifier {
|
||||||
ty
|
ty
|
||||||
} else {
|
} else {
|
||||||
let mapping = fun
|
let mapping = fun
|
||||||
.params
|
.vars
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(k, _)| (*k, self.get_fresh_var().0))
|
.map(|(k, _)| (*k, self.get_fresh_var().0))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
Loading…
Reference in New Issue