use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue}; use std::cell::RefCell; use std::collections::HashMap; use std::fmt::Debug; use std::iter::once; use std::mem::swap; use std::ops::Deref; use std::rc::Rc; #[derive(Copy, Clone, PartialEq, Eq, Debug)] /// Handle for a type, implementated as a key in the unification table. pub struct Type(u32); #[derive(Clone)] pub struct TypeCell(Rc>); impl UnifyValue for TypeCell { type Error = NoError; fn unify_values(_: &Self, value2: &Self) -> Result { // WARN: depends on the implementation details of ena. // We do not use this to do unification, instead we perform unification // and assign the type by `union_value(key, new_value)`, which set the // value as `unify_values(key.value, new_value)`. So, we need to return // the right one. Ok(value2.clone()) } } impl UnifyKey for Type { type Value = TypeCell; fn index(&self) -> u32 { self.0 } fn from_index(u: u32) -> Self { Type(u) } fn tag() -> &'static str { "TypeID" } } impl Deref for TypeCell { type Target = Rc>; fn deref(&self) -> &::Target { &self.0 } } pub type Mapping = HashMap; pub type VarMap = Mapping; #[derive(Clone)] pub struct Call { posargs: Vec, kwargs: HashMap, ret: Type, fun: RefCell>, } #[derive(Clone)] pub struct FuncArg { name: String, ty: Type, is_optional: bool, } // We use a lot of `Rc`/`RefCell`s here as we want to simplify our code. // We may not really need so much `Rc`s, but we would have to do complicated // stuffs otherwise. pub enum TypeEnum { TVar { // TODO: upper/lower bound id: u32, }, TSeq { map: VarMap, }, TTuple { ty: Vec, }, TList { ty: Type, }, TRecord { fields: Mapping, }, TObj { obj_id: usize, fields: Mapping, params: VarMap, }, TVirtual { ty: Type, }, TCall { calls: Vec>, }, TFunc { args: Vec, ret: Type, params: VarMap, }, } // Order: // TVar // |--> TSeq // | |--> TTuple // | `--> TList // |--> TRecord // | |--> TObj // | `--> TVirtual // `--> TCall // `--> TFunc // We encode the types as natural numbers, and subtyping relation as divisibility. // If a | b, b <: a. // We assign unique prime numbers (1 to TVar, everything is a subtype of it) to each type: // TVar = 1 // |--> TSeq = 2 // | |--> TTuple = 3 // | `--> TList = 5 // |--> TRecord = 7 // | |--> TObj = 11 // | `--> TVirtual = 13 // `--> TCall = 17 // `--> TFunc = 21 // // And then, based on the subtyping relation, multiply them together... // TVar = 1 // |--> TSeq = 2 * TVar // | |--> TTuple = 3 * TSeq * TVar // | `--> TList = 5 * TSeq * TVar // |--> TRecord = 7 * TVar // | |--> TObj = 11 * TRecord * TVar // | `--> TVirtual = 13 * TRecord * TVar // `--> TCall = 17 * TVar // `--> TFunc = 21 * TCall * TVar impl TypeEnum { fn get_int(&self) -> i32 { const TVAR: i32 = 1; const TSEQ: i32 = 2; const TTUPLE: i32 = 3; const TLIST: i32 = 5; const TRECORD: i32 = 7; const TOBJ: i32 = 11; const TVIRTUAL: i32 = 13; const TCALL: i32 = 17; const TFUNC: i32 = 21; match self { TypeEnum::TVar { .. } => TVAR, TypeEnum::TSeq { .. } => TSEQ * TVAR, TypeEnum::TTuple { .. } => TTUPLE * TSEQ * TVAR, TypeEnum::TList { .. } => TLIST * TSEQ * TVAR, TypeEnum::TRecord { .. } => TRECORD * TVAR, TypeEnum::TObj { .. } => TOBJ * TRECORD * TVAR, TypeEnum::TVirtual { .. } => TVIRTUAL * TRECORD * TVAR, TypeEnum::TCall { .. } => TCALL * TVAR, TypeEnum::TFunc { .. } => TFUNC * TCALL * TVAR, } } // e.g. List <: Var pub fn kind_le(&self, other: &TypeEnum) -> bool { let a = self.get_int(); let b = other.get_int(); (a % b) == 0 } pub fn get_type_name(&self) -> &'static str { // this function is for debugging only... // a proper to_str implementation requires the context match self { TypeEnum::TVar { .. } => "TVar", TypeEnum::TSeq { .. } => "TSeq", TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TList { .. } => "TList", TypeEnum::TRecord { .. } => "TRecord", TypeEnum::TObj { .. } => "TObj", TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TCall { .. } => "TCall", TypeEnum::TFunc { .. } => "TFunc", } } } impl Debug for TypeCell { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(&self.borrow().get_type_name()) } } pub struct ObjDef { name: String, fields: Mapping, } pub struct Unifier { unification_table: RefCell>, obj_def_table: Vec, } impl Unifier { pub fn new() -> Unifier { Unifier { unification_table: RefCell::new(InPlaceUnificationTable::new()), obj_def_table: Vec::new(), } } /// Register a type to the unifier. /// Returns a key in the unification_table. pub fn add_ty(&self, a: TypeEnum) -> Type { self.unification_table .borrow_mut() .new_key(TypeCell(Rc::new(a.into()))) } /// Get the TypeEnum of a type. pub fn get_ty(&self, a: Type) -> Rc> { let mut table = self.unification_table.borrow_mut(); table.probe_value(a).0 } pub fn unify(&self, mut a: Type, mut b: Type) -> Result<(), String> { let (mut ty_a_cell, mut ty_b_cell) = { let mut table = self.unification_table.borrow_mut(); if table.unioned(a, b) { return Ok(()); } (table.probe_value(a), table.probe_value(b)) }; let (ty_a, ty_b) = { // simplify our pattern matching... if ty_a_cell.borrow().kind_le(&ty_b_cell.borrow()) { swap(&mut a, &mut b); swap(&mut ty_a_cell, &mut ty_b_cell); } (ty_a_cell.borrow(), ty_b_cell.borrow()) }; self.occur_check(a, b)?; match &*ty_a { TypeEnum::TVar { .. } => { // TODO: type variables bound check... self.set_a_to_b(a, b); } TypeEnum::TSeq { map: map1 } => { match &*ty_b { TypeEnum::TSeq { .. } => { drop(ty_b); if let TypeEnum::TSeq { map: map2 } = &mut *ty_b_cell.as_ref().borrow_mut() { // unify them to map2 for (key, value) in map1.iter() { if let Some(ty) = map2.get(key) { self.unify(*ty, *value)?; } else { map2.insert(*key, *value); } } } else { unreachable!() } self.set_a_to_b(a, b); } TypeEnum::TTuple { ty: types } => { let len = types.len() as u32; for (k, v) in map1.iter() { if *k >= len { return Err(format!( "Tuple index out of range. (Length: {}, Index: {})", types.len(), k )); } self.unify(*v, types[*k as usize])?; } self.set_a_to_b(a, b); } TypeEnum::TList { ty } => { for v in map1.values() { self.unify(*v, *ty)?; } self.set_a_to_b(a, b); } _ => { return self.report_kind_error(&*ty_a, &*ty_b); } } } TypeEnum::TTuple { ty: ty1 } => { if let TypeEnum::TTuple { ty: ty2 } = &*ty_b { if ty1.len() != ty2.len() { return Err(format!( "Cannot unify tuples with length {} and {}", ty1.len(), ty2.len() )); } for (x, y) in ty1.iter().zip(ty2.iter()) { self.unify(*x, *y)?; } self.set_a_to_b(a, b); } else { return self.report_kind_error(&*ty_a, &*ty_b); } } TypeEnum::TList { ty: ty1 } => { if let TypeEnum::TList { ty: ty2 } = *ty_b { self.unify(*ty1, ty2)?; self.set_a_to_b(a, b); } else { return self.report_kind_error(&*ty_a, &*ty_b); } } TypeEnum::TRecord { fields: fields1 } => { match &*ty_b { TypeEnum::TRecord { .. } => { drop(ty_b); if let TypeEnum::TRecord { fields: fields2 } = &mut *ty_b_cell.as_ref().borrow_mut() { for (key, value) in fields1.iter() { if let Some(ty) = fields2.get(key) { self.unify(*ty, *value)?; } else { fields2.insert(key.clone(), *value); } } } else { unreachable!() } self.set_a_to_b(a, b); } TypeEnum::TObj { fields: fields2, .. } => { for (key, value) in fields1.iter() { if let Some(ty) = fields2.get(key) { self.unify(*ty, *value)?; } else { return Err(format!("No such attribute {}", key)); } } self.set_a_to_b(a, b); } TypeEnum::TVirtual { ty } => { // not sure if this is correct... self.unify(a, *ty)?; self.set_a_to_b(a, b); } _ => { return self.report_kind_error(&*ty_a, &*ty_b); } } } TypeEnum::TObj { obj_id: id1, params: params1, .. } => { if let TypeEnum::TObj { obj_id: id2, params: params2, .. } = &*ty_b { if id1 != id2 { return Err(format!("Cannot unify objects with ID {} and {}", id1, id2)); } for (x, y) in params1.values().zip(params2.values()) { self.unify(*x, *y)?; } self.set_a_to_b(a, b); } else { return self.report_kind_error(&*ty_a, &*ty_b); } } TypeEnum::TVirtual { ty: ty1 } => { if let TypeEnum::TVirtual { ty: ty2 } = &*ty_b { self.unify(*ty1, *ty2)?; self.set_a_to_b(a, b); } else { return self.report_kind_error(&*ty_a, &*ty_b); } } _ => unimplemented!(), } Ok(()) } fn set_a_to_b(&self, a: Type, b: Type) { // unify a and b together, and set the value to b's value. let mut table = self.unification_table.borrow_mut(); let ty_b = table.probe_value(b); table.union(a, b); table.union_value(a, ty_b); } fn report_kind_error(&self, a: &TypeEnum, b: &TypeEnum) -> Result<(), String> { Err(format!( "Cannot unify {} with {}", a.get_type_name(), b.get_type_name() )) } fn occur_check(&self, a: Type, b: Type) -> Result<(), String> { if self.unification_table.borrow_mut().unioned(a, b) { return Err("Recursive type is prohibited.".to_owned()); } let ty = self.unification_table.borrow_mut().probe_value(b); let ty = ty.borrow(); match &*ty { TypeEnum::TVar { .. } => { // TODO: occur check for bounds... } TypeEnum::TSeq { map } | TypeEnum::TObj { params: map, .. } => { for t in map.values() { self.occur_check(a, *t)?; } } TypeEnum::TTuple { ty } => { for t in ty.iter() { self.occur_check(a, *t)?; } } TypeEnum::TList { ty } | TypeEnum::TVirtual { ty } => { self.occur_check(a, *ty)?; } TypeEnum::TRecord { fields } => { for t in fields.values() { self.occur_check(a, *t)?; } } TypeEnum::TCall { calls } => { for t in calls .iter() .map(|call| { call.posargs .iter() .chain(call.kwargs.values()) .chain(once(&call.ret)) }) .flatten() { self.occur_check(a, *t)?; } } TypeEnum::TFunc { args, ret, params } => { for t in args .iter() .map(|v| &v.ty) .chain(params.values()) .chain(once(ret)) { self.occur_check(a, *t)?; } } }; Ok(()) } pub fn subst(&self, a: Type, mapping: &VarMap) -> Option { let ty_cell = self.unification_table.borrow_mut().probe_value(a); let ty = ty_cell.borrow(); // this function would only be called when we instantiate functions. // function type signature should ONLY contain concrete types and type // variables, i.e. things like TRecord, TCall should not occur, and we // should be safe to not implement the substitution for those variants. match &*ty { TypeEnum::TVar { id } => mapping.get(&id).cloned(), TypeEnum::TSeq { map } => self .subst_map(map, mapping) .map(|m| self.add_ty(TypeEnum::TSeq { map: m })), TypeEnum::TTuple { ty } => { let mut new_ty = None; for (i, t) in ty.iter().enumerate() { if let Some(t1) = self.subst(*t, mapping) { if new_ty.is_none() { new_ty = Some(ty.clone()); } new_ty.as_mut().unwrap()[i] = t1; } } new_ty.map(|t| self.add_ty(TypeEnum::TTuple { ty: t })) } TypeEnum::TList { ty } => self .subst(*ty, mapping) .map(|t| self.add_ty(TypeEnum::TList { ty: t })), TypeEnum::TVirtual { ty } => self .subst(*ty, mapping) .map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })), TypeEnum::TObj { obj_id, fields, params, } => { // Type variables in field types must be present in the type parameter. // If the mapping does not contain any type variables in the // parameter list, we don't need to substitute the fields. // This is also used to prevent infinite substitution... let need_subst = params.values().any(|v| { let ty_cell = self.unification_table.borrow_mut().probe_value(*v); let ty = ty_cell.borrow(); if let TypeEnum::TVar { id } = &*ty { mapping.contains_key(id) } else { false } }); if need_subst { let obj_id = *obj_id; let params = self .subst_map(¶ms, mapping) .unwrap_or_else(|| params.clone()); let fields = self .subst_map(&fields, mapping) .unwrap_or_else(|| fields.clone()); Some(self.add_ty(TypeEnum::TObj { obj_id, params, fields, })) } else { None } } TypeEnum::TFunc { args, ret, params } => { let new_params = self.subst_map(params, mapping); let new_ret = self.subst(*ret, mapping); let mut new_args = None; for (i, t) in args.iter().enumerate() { if let Some(t1) = self.subst(t.ty, mapping) { if new_args.is_none() { new_args = Some(args.clone()); } new_args.as_mut().unwrap()[i] = FuncArg { name: t.name.clone(), ty: t1, is_optional: t.is_optional, }; } } if new_params.is_some() || new_ret.is_some() || new_args.is_some() { let params = new_params.unwrap_or_else(|| params.clone()); let ret = new_ret.unwrap_or_else(|| *ret); let args = new_args.unwrap_or_else(|| args.clone()); Some(self.add_ty(TypeEnum::TFunc { params, ret, args })) } else { None } } _ => unimplemented!(), } } fn subst_map(&self, map: &Mapping, mapping: &VarMap) -> Option> where K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, { let mut map2 = None; for (k, v) in map.iter() { if let Some(v1) = self.subst(*v, mapping) { if map2.is_none() { map2 = Some(map.clone()); } *map2.as_mut().unwrap().get_mut(k).unwrap() = v1; } } map2 } pub fn eq(&self, a: Type, b: Type) -> bool { if a == b { return true; } let (ty_a, ty_b) = { let mut table = self.unification_table.borrow_mut(); if table.unioned(a, b) { return true; } (table.probe_value(a), table.probe_value(b)) }; let ty_a = ty_a.borrow(); let ty_b = ty_b.borrow(); match (&*ty_a, &*ty_b) { (TypeEnum::TVar { id: id1 }, TypeEnum::TVar { id: id2 }) => id1 == id2, (TypeEnum::TSeq { map: map1 }, TypeEnum::TSeq { map: map2 }) => self.map_eq(map1, map2), (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => { ty1.len() == ty2.len() && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) } (TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 }) | (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => { self.eq(*ty1, *ty2) } (TypeEnum::TRecord { fields: fields1 }, TypeEnum::TRecord { fields: fields2 }) => { self.map_eq(fields1, fields2) } ( TypeEnum::TObj { obj_id: id1, params: params1, .. }, TypeEnum::TObj { obj_id: id2, params: params2, .. }, ) => id1 == id2 && self.map_eq(params1, params2), // TCall and TFunc are not yet implemented _ => false, } } fn map_eq(&self, map1: &Mapping, map2: &Mapping) -> bool where K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, { if map1.len() != map2.len() { return false; } for (k, v) in map1.iter() { if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) { return false; } } true } }