hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
5 changed files with 122 additions and 59 deletions
Showing only changes of commit 09c9218852 - Show all commits

1
Cargo.lock generated
View File

@ -384,7 +384,6 @@ checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc"
name = "nac3core" name = "nac3core"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"ena",
"indoc 1.0.3", "indoc 1.0.3",
"inkwell", "inkwell",
"itertools", "itertools",

View File

@ -9,7 +9,6 @@ num-bigint = "0.3"
num-traits = "0.2" num-traits = "0.2"
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] } inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] }
rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" } rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" }
ena = "0.14"
itertools = "0.10.1" itertools = "0.10.1"
[dev-dependencies] [dev-dependencies]

View File

@ -4,3 +4,4 @@ mod magic_methods;
pub mod symbol_resolver; pub mod symbol_resolver;
pub mod typedef; pub mod typedef;
pub mod type_inferencer; pub mod type_inferencer;
mod unification_table;

View File

@ -1,55 +1,20 @@
use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue};
use itertools::Itertools; use itertools::Itertools;
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::once; use std::iter::once;
use std::ops::Deref;
use std::rc::Rc; use std::rc::Rc;
use super::unification_table::{UnificationKey, UnificationTable};
#[cfg(test)] #[cfg(test)]
mod test; mod test;
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
/// Handle for a type, implementated as a key in the unification table. /// Handle for a type, implementated as a key in the unification table.
pub struct Type(u32); pub type Type = UnificationKey;
#[derive(Clone)] #[derive(Clone)]
pub struct TypeCell(Rc<RefCell<TypeEnum>>); pub struct TypeCell(Rc<RefCell<TypeEnum>>);
impl UnifyValue for TypeCell {
type Error = NoError;
fn unify_values(_: &Self, value2: &Self) -> Result<Self, Self::Error> {
// WARN: depends on the implementation details of ena.
// We do not use this to do unification, instead we perform unification
// and assign the type by `union_value(key, new_value)`, which set the
// value as `unify_values(key.value, new_value)`. So, we need to return
// the right one.
Ok(value2.clone())
}
}
impl UnifyKey for Type {
type Value = TypeCell;
fn index(&self) -> u32 {
self.0
}
fn from_index(u: u32) -> Self {
Type(u)
}
fn tag() -> &'static str {
"TypeID"
}
}
impl Deref for TypeCell {
type Target = Rc<RefCell<TypeEnum>>;
fn deref(&self) -> &<Self as Deref>::Target {
&self.0
}
}
pub type Mapping<K, V = Type> = HashMap<K, V>; pub type Mapping<K, V = Type> = HashMap<K, V>;
type VarMap = Mapping<u32>; type VarMap = Mapping<u32>;
@ -78,6 +43,7 @@ pub struct FunSignature {
// We use a lot of `Rc`/`RefCell`s here as we want to simplify our code. // We use a lot of `Rc`/`RefCell`s here as we want to simplify our code.
// We may not really need so much `Rc`s, but we would have to do complicated // We may not really need so much `Rc`s, but we would have to do complicated
// stuffs otherwise. // stuffs otherwise.
#[derive(Clone)]
pub enum TypeEnum { pub enum TypeEnum {
TVar { TVar {
// TODO: upper/lower bound // TODO: upper/lower bound
@ -138,14 +104,8 @@ impl TypeEnum {
} }
} }
impl Debug for TypeCell {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.borrow().get_type_name())
}
}
pub struct Unifier { pub struct Unifier {
unification_table: InPlaceUnificationTable<Type>, unification_table: UnificationTable<Rc<RefCell<TypeEnum>>>,
var_id: u32, var_id: u32,
} }
@ -153,7 +113,7 @@ impl Unifier {
/// Get an empty unifier /// Get an empty unifier
pub fn new() -> Unifier { pub fn new() -> Unifier {
Unifier { Unifier {
unification_table: InPlaceUnificationTable::new(), unification_table: UnificationTable::new(),
var_id: 0, var_id: 0,
} }
} }
@ -161,12 +121,12 @@ impl Unifier {
/// Register a type to the unifier. /// Register a type to the unifier.
/// Returns a key in the unification_table. /// Returns a key in the unification_table.
pub fn add_ty(&mut self, a: TypeEnum) -> Type { pub fn add_ty(&mut self, a: TypeEnum) -> Type {
self.unification_table.new_key(TypeCell(Rc::new(a.into()))) self.unification_table.new_key(Rc::new(a.into()))
} }
/// Get the TypeEnum of a type. /// Get the TypeEnum of a type.
pub fn get_ty(&mut self, a: Type) -> Rc<RefCell<TypeEnum>> { pub fn get_ty(&mut self, a: Type) -> Rc<RefCell<TypeEnum>> {
self.unification_table.probe_value(a).0 self.unification_table.probe_value(a).clone()
} }
/// Unify two types, i.e. a = b. /// Unify two types, i.e. a = b.
@ -187,7 +147,7 @@ impl Unifier {
F: FnMut(usize) -> String, F: FnMut(usize) -> String,
G: FnMut(u32) -> String, G: FnMut(u32) -> String,
{ {
let ty = self.unification_table.probe_value(ty).0; let ty = self.unification_table.probe_value(ty).clone();
let ty = ty.as_ref().borrow(); let ty = ty.as_ref().borrow();
match &*ty { match &*ty {
TypeEnum::TVar { id } => var_to_name(*id), TypeEnum::TVar { id } => var_to_name(*id),
@ -252,8 +212,8 @@ impl Unifier {
return Ok(()); return Ok(());
} }
( (
self.unification_table.probe_value(a), self.unification_table.probe_value(a).clone(),
self.unification_table.probe_value(b), self.unification_table.probe_value(b).clone(),
) )
}; };
@ -484,9 +444,9 @@ impl Unifier {
fn set_a_to_b(&mut self, a: Type, b: Type) { fn set_a_to_b(&mut self, a: Type, b: Type) {
// unify a and b together, and set the value to b's value. // unify a and b together, and set the value to b's value.
let table = &mut self.unification_table; let table = &mut self.unification_table;
let ty_b = table.probe_value(b); let ty_b = table.probe_value(b).clone();
table.union(a, b); table.unify(a, b);
table.union_value(a, ty_b); table.set_value(a, ty_b)
} }
fn incompatible_types(&self, a: &TypeEnum, b: &TypeEnum) -> Result<(), String> { fn incompatible_types(&self, a: &TypeEnum, b: &TypeEnum) -> Result<(), String> {
@ -501,7 +461,7 @@ impl Unifier {
if self.unification_table.unioned(a, b) { if self.unification_table.unioned(a, b) {
return Err("Recursive type is prohibited.".to_owned()); return Err("Recursive type is prohibited.".to_owned());
} }
let ty = self.unification_table.probe_value(b); let ty = self.unification_table.probe_value(b).clone();
let ty = ty.borrow(); let ty = ty.borrow();
match &*ty { match &*ty {
@ -568,7 +528,7 @@ impl Unifier {
/// If this returns None, the result type would be the original type /// If this returns None, the result type would be the original type
/// (no substitution has to be done). /// (no substitution has to be done).
fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> { fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
let ty_cell = self.unification_table.probe_value(a); let ty_cell = self.unification_table.probe_value(a).clone();
let ty = ty_cell.borrow(); let ty = ty_cell.borrow();
// this function would only be called when we instantiate functions. // this function would only be called when we instantiate functions.
// function type signature should ONLY contain concrete types and type // function type signature should ONLY contain concrete types and type
@ -725,7 +685,7 @@ impl Unifier {
if table.unioned(a, b) { if table.unioned(a, b) {
return true; return true;
} }
(table.probe_value(a), table.probe_value(b)) (table.probe_value(a).clone(), table.probe_value(b).clone())
}; };
let ty_a = ty_a.borrow(); let ty_a = ty_a.borrow();

View File

@ -0,0 +1,104 @@
use std::cell::RefCell;
use std::rc::Rc;
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub struct UnificationKey(usize);
pub struct UnificationTable<V> {
parents: Vec<usize>,
ranks: Vec<u32>,
values: Vec<V>,
}
impl<V> UnificationTable<V> {
pub fn new() -> UnificationTable<V> {
UnificationTable {
parents: Vec::new(),
ranks: Vec::new(),
values: Vec::new(),
}
}
pub fn new_key(&mut self, v: V) -> UnificationKey {
let index = self.parents.len();
self.parents.push(index);
self.ranks.push(0);
self.values.push(v);
UnificationKey(index)
}
pub fn unify(&mut self, a: UnificationKey, b: UnificationKey) {
let mut a = self.find(a);
let mut b = self.find(b);
if a == b {
return;
}
if self.ranks[a] < self.ranks[b] {
std::mem::swap(&mut a, &mut b);
}
self.parents[b] = a;
if self.ranks[a] == self.ranks[b] {
self.ranks[a] += 1;
}
}
pub fn probe_value(&mut self, a: UnificationKey) -> &V {
let index = self.find(a);
&self.values[index]
}
pub fn set_value(&mut self, a: UnificationKey, v: V) {
let index = self.find(a);
self.values[index] = v;
}
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
self.find(a) == self.find(b)
}
fn find(&mut self, key: UnificationKey) -> usize {
let mut root = key.0;
let mut parent = self.parents[root];
while root != parent {
// a = parent.parent
let a = self.parents[parent];
// root.parent = parent.parent
self.parents[root] = a;
root = parent;
// parent = root.parent
parent = a;
}
parent
}
}
impl<V> UnificationTable<Rc<RefCell<V>>>
where
V: Clone,
{
pub fn into_send(self) -> UnificationTable<V> {
let values = self
.values
.iter()
.map(|v| v.as_ref().borrow().clone())
.collect();
UnificationTable {
parents: self.parents,
ranks: self.ranks,
values,
}
}
pub fn from_send(table: UnificationTable<V>) -> UnificationTable<Rc<RefCell<V>>> {
let values = table
.values
.into_iter()
.map(|v| Rc::new(RefCell::new(v)))
.collect();
UnificationTable {
parents: table.parents,
ranks: table.ranks,
values,
}
}
}