function unification...

This commit is contained in:
pca006132 2021-07-16 15:55:52 +08:00
parent f4121b570d
commit d67407716c
2 changed files with 177 additions and 56 deletions

View File

@ -8,14 +8,12 @@ mod test {
/// Shared fixture for the unifier tests: a `Unifier` plus a table of named
/// types that the tests can refer to by string.
struct TestEnvironment {
/// The unifier under test; types are registered in it via `add_ty`.
pub unifier: Unifier,
/// Types registered by `new`, keyed by name (for example `"int"` and `"Foo"`).
type_mapping: HashMap<String, Type>,
/// Counter used by `get_fresh_var` to number the type variables it creates.
var_max_id: u32,
}
impl TestEnvironment {
fn new() -> TestEnvironment {
let unifier = Unifier::new();
let mut unifier = Unifier::new();
let mut type_mapping = HashMap::new();
let mut var_max_id = 0;
type_mapping.insert(
"int".into(),
@ -41,38 +39,30 @@ mod test {
params: HashMap::new(),
}),
);
let v0 = unifier.add_ty(TypeEnum::TVar { id: 0 });
var_max_id += 1;
let (v0, id) = unifier.get_fresh_var();
type_mapping.insert(
"Foo".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: 3,
fields: [("a".into(), v0)].iter().cloned().collect(),
params: [(0u32, v0)].iter().cloned().collect(),
params: [(id, v0)].iter().cloned().collect(),
}),
);
TestEnvironment {
unifier,
type_mapping,
var_max_id,
}
}
/// Create a new type variable and register it with the unifier.
///
/// The counter `var_max_id` is advanced first, and its new value becomes
/// the id of the `TVar` that is added.
fn get_fresh_var(&mut self) -> Type {
    self.var_max_id += 1;
    let fresh = TypeEnum::TVar {
        id: self.var_max_id,
    };
    self.unifier.add_ty(fresh)
}
fn parse(&self, typ: &str, mapping: &Mapping<String>) -> Type {
fn parse(&mut self, typ: &str, mapping: &Mapping<String>) -> Type {
let result = self.internal_parse(typ, mapping);
assert!(result.1.is_empty());
result.0
}
fn internal_parse<'a, 'b>(
&'a self,
&'a mut self,
typ: &'b str,
mapping: &Mapping<String>,
) -> (Type, &'b str) {
@ -189,8 +179,8 @@ mod test {
let mut env = TestEnvironment::new();
let mut mapping = HashMap::new();
for i in 1..=variable_count {
let v = env.get_fresh_var();
mapping.insert(format!("v{}", i), v);
let v = env.unifier.get_fresh_var();
mapping.insert(format!("v{}", i), v.0);
}
// unification may have side effects when we do type resolution, so freeze the types
// before doing unification.
@ -259,8 +249,8 @@ mod test {
let mut env = TestEnvironment::new();
let mut mapping = HashMap::new();
for i in 1..=variable_count {
let v = env.get_fresh_var();
mapping.insert(format!("v{}", i), v);
let v = env.unifier.get_fresh_var();
mapping.insert(format!("v{}", i), v.0);
}
// unification may have side effects when we do type resolution, so freeze the types
// before doing unification.

View File

@ -48,7 +48,7 @@ impl Deref for TypeCell {
}
pub type Mapping<K, V = Type> = HashMap<K, V>;
pub type VarMap = Mapping<u32>;
type VarMap = Mapping<u32>;
#[derive(Clone)]
pub struct Call {
@ -65,6 +65,13 @@ pub struct FuncArg {
is_optional: bool,
}
/// Signature of a function type; this is the payload of `TypeEnum::TFunc`.
#[derive(Clone)]
pub struct FunSignature {
/// Declared arguments, in order. Each `FuncArg` carries a name, a type and
/// an `is_optional` flag.
args: Vec<FuncArg>,
/// Return type of the function.
ret: Type,
/// Type parameters of the function, keyed by type variable id.
/// A signature with a non-empty `params` is rejected when two function
/// types are unified ("Polymorphic function pointer is prohibited.").
params: VarMap,
}
// We use a lot of `Rc`/`RefCell`s here because they keep the code simple.
// We may not really need this many `Rc`s, but avoiding them would require
// considerably more complicated code.
@ -96,11 +103,7 @@ pub enum TypeEnum {
TCall {
calls: Vec<Rc<Call>>,
},
TFunc {
args: Vec<FuncArg>,
ret: Type,
params: VarMap,
},
TFunc(FunSignature),
}
// Order:
@ -199,40 +202,41 @@ pub struct ObjDef {
}
pub struct Unifier {
unification_table: RefCell<InPlaceUnificationTable<Type>>,
unification_table: InPlaceUnificationTable<Type>,
obj_def_table: Vec<ObjDef>,
var_id: u32,
}
impl Unifier {
pub fn new() -> Unifier {
Unifier {
unification_table: RefCell::new(InPlaceUnificationTable::new()),
unification_table: InPlaceUnificationTable::new(),
obj_def_table: Vec::new(),
var_id: 0,
}
}
/// Register a type to the unifier.
/// Returns a key in the unification_table.
pub fn add_ty(&self, a: TypeEnum) -> Type {
self.unification_table
.borrow_mut()
.new_key(TypeCell(Rc::new(a.into())))
pub fn add_ty(&mut self, a: TypeEnum) -> Type {
self.unification_table.new_key(TypeCell(Rc::new(a.into())))
}
/// Get the TypeEnum of a type.
pub fn get_ty(&self, a: Type) -> Rc<RefCell<TypeEnum>> {
let mut table = self.unification_table.borrow_mut();
table.probe_value(a).0
pub fn get_ty(&mut self, a: Type) -> Rc<RefCell<TypeEnum>> {
self.unification_table.probe_value(a).0
}
/// Unify two types, i.e. a = b.
pub fn unify(&self, mut a: Type, mut b: Type) -> Result<(), String> {
pub fn unify(&mut self, mut a: Type, mut b: Type) -> Result<(), String> {
let (mut ty_a_cell, mut ty_b_cell) = {
let mut table = self.unification_table.borrow_mut();
if table.unioned(a, b) {
if self.unification_table.unioned(a, b) {
return Ok(());
}
(table.probe_value(a), table.probe_value(b))
(
self.unification_table.probe_value(a),
self.unification_table.probe_value(b),
)
};
let (ty_a, ty_b) = {
@ -353,7 +357,6 @@ impl Unifier {
TypeEnum::TVirtual { ty } => {
// not sure if this is correct...
self.unify(a, *ty)?;
self.set_a_to_b(a, b);
}
_ => {
return self.incompatible_types(&*ty_a, &*ty_b);
@ -390,14 +393,105 @@ impl Unifier {
return self.incompatible_types(&*ty_a, &*ty_b);
}
}
_ => unimplemented!(),
TypeEnum::TCall { calls: c1 } => match &*ty_b {
TypeEnum::TCall { .. } => {
drop(ty_b);
if let TypeEnum::TCall { calls: c2 } = &mut *ty_b_cell.as_ref().borrow_mut() {
c2.extend(c1.iter().cloned());
} else {
unreachable!()
}
self.set_a_to_b(a, b);
}
TypeEnum::TFunc(signature) => {
let required: Vec<String> = signature
.args
.iter()
.filter(|v| !v.is_optional)
.map(|v| v.name.clone())
.rev()
.collect();
for c in c1 {
let Call {
posargs,
kwargs,
ret,
fun,
} = c.as_ref();
let instantiated = self.instantiate_fun(b, signature);
let signature;
let r = self.get_ty(instantiated);
let r = r.as_ref().borrow();
if let TypeEnum::TFunc(s) = &*r {
signature = s;
} else {
unreachable!();
}
let mut required = required.clone();
let mut all_names: Vec<_> = signature
.args
.iter()
.map(|v| (v.name.clone(), v.ty))
.rev()
.collect();
for (i, t) in posargs.iter().enumerate() {
if signature.args.len() <= i {
return Err(format!("Too many arguments."));
}
if !required.is_empty() {
required.pop();
}
self.unify(all_names.pop().unwrap().1, *t)?;
}
for (k, t) in kwargs.iter() {
if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i);
}
if let Some(i) = all_names.iter().position(|v| &v.0 == k) {
self.unify(all_names.remove(i).1, *t)?;
} else {
return Err(format!("Unknown keyword argument {}", k));
}
}
self.unify(*ret, signature.ret)?;
*fun.borrow_mut() = Some(instantiated);
}
self.set_a_to_b(a, b);
}
_ => {
return self.incompatible_types(&*ty_a, &*ty_b);
}
},
TypeEnum::TFunc(sign1) => {
if let TypeEnum::TFunc(sign2) = &*ty_b {
if !sign1.params.is_empty() || !sign2.params.is_empty() {
return Err(format!("Polymorphic function pointer is prohibited."));
}
if sign1.args.len() != sign2.args.len() {
return Err(format!("Functions differ in number of parameters."));
}
for (x, y) in sign1.args.iter().zip(sign2.args.iter()) {
if x.name != y.name {
return Err(format!("Functions differ in parameter names."));
}
if x.is_optional != y.is_optional {
return Err(format!("Functions differ in optional parameters."));
}
self.unify(x.ty, y.ty)?;
}
self.unify(sign1.ret, sign2.ret)?;
self.set_a_to_b(a, b);
} else {
return self.incompatible_types(&*ty_a, &*ty_b);
}
}
}
Ok(())
}
fn set_a_to_b(&self, a: Type, b: Type) {
fn set_a_to_b(&mut self, a: Type, b: Type) {
// unify a and b together, and set the value to b's value.
let mut table = self.unification_table.borrow_mut();
let table = &mut self.unification_table;
let ty_b = table.probe_value(b);
table.union(a, b);
table.union_value(a, ty_b);
@ -411,11 +505,11 @@ impl Unifier {
))
}
fn occur_check(&self, a: Type, b: Type) -> Result<(), String> {
if self.unification_table.borrow_mut().unioned(a, b) {
fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> {
if self.unification_table.unioned(a, b) {
return Err("Recursive type is prohibited.".to_owned());
}
let ty = self.unification_table.borrow_mut().probe_value(b);
let ty = self.unification_table.probe_value(b);
let ty = ty.borrow();
match &*ty {
@ -454,7 +548,7 @@ impl Unifier {
self.occur_check(a, *t)?;
}
}
TypeEnum::TFunc { args, ret, params } => {
TypeEnum::TFunc(FunSignature { args, ret, params }) => {
for t in args
.iter()
.map(|v| &v.ty)
@ -472,8 +566,8 @@ impl Unifier {
/// If this returns Some(T), T would be the substituted type.
/// If this returns None, the result type would be the original type
/// (no substitution has to be done).
pub fn subst(&self, a: Type, mapping: &VarMap) -> Option<Type> {
let ty_cell = self.unification_table.borrow_mut().probe_value(a);
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
let ty_cell = self.unification_table.probe_value(a);
let ty = ty_cell.borrow();
// this function would only be called when we instantiate functions.
// function type signature should ONLY contain concrete types and type
@ -512,10 +606,10 @@ impl Unifier {
// parameter list, we don't need to substitute the fields.
// This is also used to prevent infinite substitution...
let need_subst = params.values().any(|v| {
let ty_cell = self.unification_table.borrow_mut().probe_value(*v);
let ty_cell = self.unification_table.probe_value(*v);
let ty = ty_cell.borrow();
if let TypeEnum::TVar { id } = &*ty {
mapping.contains_key(id)
mapping.contains_key(&id)
} else {
false
}
@ -537,7 +631,7 @@ impl Unifier {
None
}
}
TypeEnum::TFunc { args, ret, params } => {
TypeEnum::TFunc(FunSignature { args, ret, params }) => {
let new_params = self.subst_map(params, mapping);
let new_ret = self.subst(*ret, mapping);
let mut new_args = None;
@ -557,7 +651,7 @@ impl Unifier {
let params = new_params.unwrap_or_else(|| params.clone());
let ret = new_ret.unwrap_or_else(|| *ret);
let args = new_args.unwrap_or_else(|| args.clone());
Some(self.add_ty(TypeEnum::TFunc { params, ret, args }))
Some(self.add_ty(TypeEnum::TFunc(FunSignature { params, ret, args })))
} else {
None
}
@ -566,7 +660,7 @@ impl Unifier {
}
}
fn subst_map<K>(&self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{
@ -582,13 +676,43 @@ impl Unifier {
map2
}
/// Instantiate a function type if it has not been instantiated yet.
///
/// `ty` is the function type and `fun` is its signature.
///
/// The signature counts as already instantiated when at least one entry of
/// `fun.params` no longer resolves to the type variable whose id equals the
/// entry's key (it resolves to a different variable or to a non-variable
/// type). In that case `ty` is returned unchanged.
///
/// Otherwise every key of `fun.params` is mapped to a fresh type variable,
/// and the result of substituting that mapping into `ty` is returned. If the
/// substitution reports that nothing had to change, `ty` itself is returned.
fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type {
    let already_instantiated = fun.params.iter().any(|(k, v)| {
        let cell = self.unification_table.probe_value(*v);
        let inner = cell.as_ref().borrow();
        // Still generic only while the parameter is the variable named by its key.
        !matches!(&*inner, TypeEnum::TVar { id } if id == k)
    });
    if already_instantiated {
        return ty;
    }
    // Replace every type parameter with a fresh variable.
    let mapping = fun
        .params
        .keys()
        .map(|k| (*k, self.get_fresh_var().0))
        .collect();
    self.subst(ty, &mapping).unwrap_or(ty)
}
/// Check whether two types are equal.
pub fn eq(&self, a: Type, b: Type) -> bool {
pub fn eq(&mut self, a: Type, b: Type) -> bool {
if a == b {
return true;
}
let (ty_a, ty_b) = {
let mut table = self.unification_table.borrow_mut();
let table = &mut self.unification_table;
if table.unioned(a, b) {
return true;
}
@ -629,7 +753,7 @@ impl Unifier {
}
}
fn map_eq<K>(&self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{
@ -643,4 +767,11 @@ impl Unifier {
}
true
}
/// Create a fresh type variable.
///
/// Advances the internal `var_id` counter, registers a `TVar` carrying the
/// new id, and returns the registered type together with that id.
pub fn get_fresh_var(&mut self) -> (Type, u32) {
    self.var_id += 1;
    let id = self.var_id;
    let var = self.add_ty(TypeEnum::TVar { id });
    (var, id)
}
}