hm-inference #6
|
@ -8,14 +8,12 @@ mod test {
|
|||
struct TestEnvironment {
|
||||
pub unifier: Unifier,
|
||||
type_mapping: HashMap<String, Type>,
|
||||
var_max_id: u32,
|
||||
}
|
||||
|
||||
impl TestEnvironment {
|
||||
fn new() -> TestEnvironment {
|
||||
let unifier = Unifier::new();
|
||||
let mut unifier = Unifier::new();
|
||||
let mut type_mapping = HashMap::new();
|
||||
let mut var_max_id = 0;
|
||||
|
||||
type_mapping.insert(
|
||||
"int".into(),
|
||||
|
@ -41,38 +39,30 @@ mod test {
|
|||
params: HashMap::new(),
|
||||
}),
|
||||
);
|
||||
let v0 = unifier.add_ty(TypeEnum::TVar { id: 0 });
|
||||
var_max_id += 1;
|
||||
let (v0, id) = unifier.get_fresh_var();
|
||||
type_mapping.insert(
|
||||
"Foo".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 3,
|
||||
fields: [("a".into(), v0)].iter().cloned().collect(),
|
||||
params: [(0u32, v0)].iter().cloned().collect(),
|
||||
params: [(id, v0)].iter().cloned().collect(),
|
||||
}),
|
||||
);
|
||||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
type_mapping,
|
||||
var_max_id,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_fresh_var(&mut self) -> Type {
|
||||
let id = self.var_max_id + 1;
|
||||
self.var_max_id += 1;
|
||||
self.unifier.add_ty(TypeEnum::TVar { id })
|
||||
}
|
||||
|
||||
fn parse(&self, typ: &str, mapping: &Mapping<String>) -> Type {
|
||||
fn parse(&mut self, typ: &str, mapping: &Mapping<String>) -> Type {
|
||||
let result = self.internal_parse(typ, mapping);
|
||||
assert!(result.1.is_empty());
|
||||
result.0
|
||||
}
|
||||
|
||||
fn internal_parse<'a, 'b>(
|
||||
&'a self,
|
||||
&'a mut self,
|
||||
typ: &'b str,
|
||||
mapping: &Mapping<String>,
|
||||
) -> (Type, &'b str) {
|
||||
|
@ -189,8 +179,8 @@ mod test {
|
|||
let mut env = TestEnvironment::new();
|
||||
let mut mapping = HashMap::new();
|
||||
for i in 1..=variable_count {
|
||||
let v = env.get_fresh_var();
|
||||
mapping.insert(format!("v{}", i), v);
|
||||
let v = env.unifier.get_fresh_var();
|
||||
mapping.insert(format!("v{}", i), v.0);
|
||||
}
|
||||
// unification may have side effect when we do type resolution, so freeze the types
|
||||
// before doing unification.
|
||||
|
@ -259,8 +249,8 @@ mod test {
|
|||
let mut env = TestEnvironment::new();
|
||||
let mut mapping = HashMap::new();
|
||||
for i in 1..=variable_count {
|
||||
let v = env.get_fresh_var();
|
||||
mapping.insert(format!("v{}", i), v);
|
||||
let v = env.unifier.get_fresh_var();
|
||||
mapping.insert(format!("v{}", i), v.0);
|
||||
}
|
||||
// unification may have side effect when we do type resolution, so freeze the types
|
||||
// before doing unification.
|
||||
|
|
|
@ -48,7 +48,7 @@ impl Deref for TypeCell {
|
|||
}
|
||||
|
||||
pub type Mapping<K, V = Type> = HashMap<K, V>;
|
||||
pub type VarMap = Mapping<u32>;
|
||||
type VarMap = Mapping<u32>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Call {
|
||||
|
@ -65,6 +65,13 @@ pub struct FuncArg {
|
|||
is_optional: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FunSignature {
|
||||
args: Vec<FuncArg>,
|
||||
ret: Type,
|
||||
params: VarMap,
|
||||
}
|
||||
|
||||
// We use a lot of `Rc`/`RefCell`s here as we want to simplify our code.
|
||||
// We may not really need so much `Rc`s, but we would have to do complicated
|
||||
// stuffs otherwise.
|
||||
|
@ -96,11 +103,7 @@ pub enum TypeEnum {
|
|||
TCall {
|
||||
calls: Vec<Rc<Call>>,
|
||||
},
|
||||
TFunc {
|
||||
args: Vec<FuncArg>,
|
||||
ret: Type,
|
||||
params: VarMap,
|
||||
},
|
||||
TFunc(FunSignature),
|
||||
}
|
||||
|
||||
// Order:
|
||||
|
@ -199,40 +202,41 @@ pub struct ObjDef {
|
|||
}
|
||||
|
||||
pub struct Unifier {
|
||||
unification_table: RefCell<InPlaceUnificationTable<Type>>,
|
||||
unification_table: InPlaceUnificationTable<Type>,
|
||||
obj_def_table: Vec<ObjDef>,
|
||||
var_id: u32,
|
||||
}
|
||||
|
||||
impl Unifier {
|
||||
pub fn new() -> Unifier {
|
||||
Unifier {
|
||||
unification_table: RefCell::new(InPlaceUnificationTable::new()),
|
||||
unification_table: InPlaceUnificationTable::new(),
|
||||
obj_def_table: Vec::new(),
|
||||
var_id: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a type to the unifier.
|
||||
/// Returns a key in the unification_table.
|
||||
pub fn add_ty(&self, a: TypeEnum) -> Type {
|
||||
self.unification_table
|
||||
.borrow_mut()
|
||||
.new_key(TypeCell(Rc::new(a.into())))
|
||||
pub fn add_ty(&mut self, a: TypeEnum) -> Type {
|
||||
self.unification_table.new_key(TypeCell(Rc::new(a.into())))
|
||||
}
|
||||
|
||||
/// Get the TypeEnum of a type.
|
||||
pub fn get_ty(&self, a: Type) -> Rc<RefCell<TypeEnum>> {
|
||||
let mut table = self.unification_table.borrow_mut();
|
||||
table.probe_value(a).0
|
||||
pub fn get_ty(&mut self, a: Type) -> Rc<RefCell<TypeEnum>> {
|
||||
self.unification_table.probe_value(a).0
|
||||
}
|
||||
|
||||
/// Unify two types, i.e. a = b.
|
||||
pub fn unify(&self, mut a: Type, mut b: Type) -> Result<(), String> {
|
||||
pub fn unify(&mut self, mut a: Type, mut b: Type) -> Result<(), String> {
|
||||
let (mut ty_a_cell, mut ty_b_cell) = {
|
||||
let mut table = self.unification_table.borrow_mut();
|
||||
if table.unioned(a, b) {
|
||||
if self.unification_table.unioned(a, b) {
|
||||
return Ok(());
|
||||
}
|
||||
(table.probe_value(a), table.probe_value(b))
|
||||
(
|
||||
self.unification_table.probe_value(a),
|
||||
self.unification_table.probe_value(b),
|
||||
)
|
||||
};
|
||||
|
||||
let (ty_a, ty_b) = {
|
||||
|
@ -353,7 +357,6 @@ impl Unifier {
|
|||
TypeEnum::TVirtual { ty } => {
|
||||
// not sure if this is correct...
|
||||
self.unify(a, *ty)?;
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
_ => {
|
||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
||||
|
@ -390,14 +393,105 @@ impl Unifier {
|
|||
return self.incompatible_types(&*ty_a, &*ty_b);
|
||||
}
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
TypeEnum::TCall { calls: c1 } => match &*ty_b {
|
||||
TypeEnum::TCall { .. } => {
|
||||
drop(ty_b);
|
||||
if let TypeEnum::TCall { calls: c2 } = &mut *ty_b_cell.as_ref().borrow_mut() {
|
||||
c2.extend(c1.iter().cloned());
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
TypeEnum::TFunc(signature) => {
|
||||
let required: Vec<String> = signature
|
||||
.args
|
||||
.iter()
|
||||
.filter(|v| !v.is_optional)
|
||||
.map(|v| v.name.clone())
|
||||
.rev()
|
||||
.collect();
|
||||
for c in c1 {
|
||||
let Call {
|
||||
posargs,
|
||||
kwargs,
|
||||
ret,
|
||||
fun,
|
||||
} = c.as_ref();
|
||||
let instantiated = self.instantiate_fun(b, signature);
|
||||
let signature;
|
||||
let r = self.get_ty(instantiated);
|
||||
let r = r.as_ref().borrow();
|
||||
if let TypeEnum::TFunc(s) = &*r {
|
||||
signature = s;
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
let mut required = required.clone();
|
||||
let mut all_names: Vec<_> = signature
|
||||
.args
|
||||
.iter()
|
||||
.map(|v| (v.name.clone(), v.ty))
|
||||
.rev()
|
||||
.collect();
|
||||
for (i, t) in posargs.iter().enumerate() {
|
||||
if signature.args.len() <= i {
|
||||
return Err(format!("Too many arguments."));
|
||||
}
|
||||
if !required.is_empty() {
|
||||
required.pop();
|
||||
}
|
||||
self.unify(all_names.pop().unwrap().1, *t)?;
|
||||
}
|
||||
for (k, t) in kwargs.iter() {
|
||||
if let Some(i) = required.iter().position(|v| v == k) {
|
||||
required.remove(i);
|
||||
}
|
||||
if let Some(i) = all_names.iter().position(|v| &v.0 == k) {
|
||||
self.unify(all_names.remove(i).1, *t)?;
|
||||
} else {
|
||||
return Err(format!("Unknown keyword argument {}", k));
|
||||
}
|
||||
}
|
||||
self.unify(*ret, signature.ret)?;
|
||||
*fun.borrow_mut() = Some(instantiated);
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
_ => {
|
||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
||||
}
|
||||
},
|
||||
TypeEnum::TFunc(sign1) => {
|
||||
if let TypeEnum::TFunc(sign2) = &*ty_b {
|
||||
if !sign1.params.is_empty() || !sign2.params.is_empty() {
|
||||
return Err(format!("Polymorphic function pointer is prohibited."));
|
||||
}
|
||||
if sign1.args.len() != sign2.args.len() {
|
||||
return Err(format!("Functions differ in number of parameters."));
|
||||
}
|
||||
for (x, y) in sign1.args.iter().zip(sign2.args.iter()) {
|
||||
if x.name != y.name {
|
||||
return Err(format!("Functions differ in parameter names."));
|
||||
}
|
||||
if x.is_optional != y.is_optional {
|
||||
return Err(format!("Functions differ in optional parameters."));
|
||||
}
|
||||
self.unify(x.ty, y.ty)?;
|
||||
}
|
||||
self.unify(sign1.ret, sign2.ret)?;
|
||||
self.set_a_to_b(a, b);
|
||||
} else {
|
||||
return self.incompatible_types(&*ty_a, &*ty_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_a_to_b(&self, a: Type, b: Type) {
|
||||
fn set_a_to_b(&mut self, a: Type, b: Type) {
|
||||
// unify a and b together, and set the value to b's value.
|
||||
let mut table = self.unification_table.borrow_mut();
|
||||
let table = &mut self.unification_table;
|
||||
let ty_b = table.probe_value(b);
|
||||
table.union(a, b);
|
||||
table.union_value(a, ty_b);
|
||||
|
@ -411,11 +505,11 @@ impl Unifier {
|
|||
))
|
||||
}
|
||||
|
||||
fn occur_check(&self, a: Type, b: Type) -> Result<(), String> {
|
||||
if self.unification_table.borrow_mut().unioned(a, b) {
|
||||
fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||
if self.unification_table.unioned(a, b) {
|
||||
return Err("Recursive type is prohibited.".to_owned());
|
||||
}
|
||||
let ty = self.unification_table.borrow_mut().probe_value(b);
|
||||
let ty = self.unification_table.probe_value(b);
|
||||
let ty = ty.borrow();
|
||||
|
||||
match &*ty {
|
||||
|
@ -454,7 +548,7 @@ impl Unifier {
|
|||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TFunc { args, ret, params } => {
|
||||
TypeEnum::TFunc(FunSignature { args, ret, params }) => {
|
||||
for t in args
|
||||
.iter()
|
||||
.map(|v| &v.ty)
|
||||
|
@ -472,8 +566,8 @@ impl Unifier {
|
|||
/// If this returns Some(T), T would be the substituted type.
|
||||
/// If this returns None, the result type would be the original type
|
||||
/// (no substitution has to be done).
|
||||
pub fn subst(&self, a: Type, mapping: &VarMap) -> Option<Type> {
|
||||
let ty_cell = self.unification_table.borrow_mut().probe_value(a);
|
||||
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
|
||||
let ty_cell = self.unification_table.probe_value(a);
|
||||
let ty = ty_cell.borrow();
|
||||
// this function would only be called when we instantiate functions.
|
||||
// function type signature should ONLY contain concrete types and type
|
||||
|
@ -512,10 +606,10 @@ impl Unifier {
|
|||
// parameter list, we don't need to substitute the fields.
|
||||
// This is also used to prevent infinite substitution...
|
||||
let need_subst = params.values().any(|v| {
|
||||
let ty_cell = self.unification_table.borrow_mut().probe_value(*v);
|
||||
let ty_cell = self.unification_table.probe_value(*v);
|
||||
let ty = ty_cell.borrow();
|
||||
if let TypeEnum::TVar { id } = &*ty {
|
||||
mapping.contains_key(id)
|
||||
mapping.contains_key(&id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
@ -537,7 +631,7 @@ impl Unifier {
|
|||
None
|
||||
}
|
||||
}
|
||||
TypeEnum::TFunc { args, ret, params } => {
|
||||
TypeEnum::TFunc(FunSignature { args, ret, params }) => {
|
||||
let new_params = self.subst_map(params, mapping);
|
||||
let new_ret = self.subst(*ret, mapping);
|
||||
let mut new_args = None;
|
||||
|
@ -557,7 +651,7 @@ impl Unifier {
|
|||
let params = new_params.unwrap_or_else(|| params.clone());
|
||||
let ret = new_ret.unwrap_or_else(|| *ret);
|
||||
let args = new_args.unwrap_or_else(|| args.clone());
|
||||
Some(self.add_ty(TypeEnum::TFunc { params, ret, args }))
|
||||
Some(self.add_ty(TypeEnum::TFunc(FunSignature { params, ret, args })))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -566,7 +660,7 @@ impl Unifier {
|
|||
}
|
||||
}
|
||||
|
||||
fn subst_map<K>(&self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
|
||||
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
{
|
||||
|
@ -582,13 +676,43 @@ impl Unifier {
|
|||
map2
|
||||
}
|
||||
|
||||
/// Instantiate a function if it hasn't been instntiated.
|
||||
/// Returns Some(T) where T is the instantiated type.
|
||||
/// Returns None if the function is already instantiated.
|
||||
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
|
||||
let mut instantiated = false;
|
||||
for (k, v) in fun.params.iter() {
|
||||
if let TypeEnum::TVar { id } =
|
||||
&*self.unification_table.probe_value(*v).as_ref().borrow()
|
||||
{
|
||||
if k != id {
|
||||
instantiated = true;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
instantiated = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if instantiated {
|
||||
ty
|
||||
} else {
|
||||
let mapping = fun
|
||||
.params
|
||||
.iter()
|
||||
.map(|(k, _)| (*k, self.get_fresh_var().0))
|
||||
.collect();
|
||||
self.subst(ty, &mapping).unwrap_or(ty)
|
||||
}
|
||||
}
|
||||
|
||||
/// Check whether two types are equal.
|
||||
pub fn eq(&self, a: Type, b: Type) -> bool {
|
||||
pub fn eq(&mut self, a: Type, b: Type) -> bool {
|
||||
if a == b {
|
||||
return true;
|
||||
}
|
||||
let (ty_a, ty_b) = {
|
||||
let mut table = self.unification_table.borrow_mut();
|
||||
let table = &mut self.unification_table;
|
||||
if table.unioned(a, b) {
|
||||
return true;
|
||||
}
|
||||
|
@ -629,7 +753,7 @@ impl Unifier {
|
|||
}
|
||||
}
|
||||
|
||||
fn map_eq<K>(&self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
|
||||
fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
{
|
||||
|
@ -643,4 +767,11 @@ impl Unifier {
|
|||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Get a fresh type variable.
|
||||
pub fn get_fresh_var(&mut self) -> (Type, u32) {
|
||||
let id = self.var_id + 1;
|
||||
self.var_id += 1;
|
||||
(self.add_ty(TypeEnum::TVar { id }), id)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue