diff --git a/Cargo.lock b/Cargo.lock index d08e77c0..d4561e97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -384,7 +384,6 @@ checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" name = "nac3core" version = "0.1.0" dependencies = [ - "ena", "indoc 1.0.3", "inkwell", "itertools", diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 402ce18b..5d3753f3 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -9,7 +9,6 @@ num-bigint = "0.3" num-traits = "0.2" inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm10-0"] } rustpython-parser = { git = "https://github.com/RustPython/RustPython", branch = "master" } -ena = "0.14" itertools = "0.10.1" [dev-dependencies] diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs index 6e50fadf..6b55ddba 100644 --- a/nac3core/src/typecheck/mod.rs +++ b/nac3core/src/typecheck/mod.rs @@ -4,3 +4,4 @@ mod magic_methods; pub mod symbol_resolver; pub mod typedef; pub mod type_inferencer; +mod unification_table; diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 680bcc67..9a355c33 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -1,55 +1,20 @@ -use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue}; use itertools::Itertools; use std::cell::RefCell; use std::collections::HashMap; -use std::fmt::Debug; use std::iter::once; -use std::ops::Deref; use std::rc::Rc; +use super::unification_table::{UnificationKey, UnificationTable}; + #[cfg(test)] mod test; -#[derive(Copy, Clone, PartialEq, Eq, Debug)] /// Handle for a type, implementated as a key in the unification table. -pub struct Type(u32); +pub type Type = UnificationKey; #[derive(Clone)] pub struct TypeCell(Rc>); -impl UnifyValue for TypeCell { - type Error = NoError; - fn unify_values(_: &Self, value2: &Self) -> Result { - // WARN: depends on the implementation details of ena. - // We do not use this to do unification, instead we perform unification - // and assign the type by `union_value(key, new_value)`, which set the - // value as `unify_values(key.value, new_value)`. So, we need to return - // the right one. - Ok(value2.clone()) - } -} - -impl UnifyKey for Type { - type Value = TypeCell; - fn index(&self) -> u32 { - self.0 - } - fn from_index(u: u32) -> Self { - Type(u) - } - fn tag() -> &'static str { - "TypeID" - } -} - -impl Deref for TypeCell { - type Target = Rc>; - - fn deref(&self) -> &::Target { - &self.0 - } -} - pub type Mapping = HashMap; type VarMap = Mapping; @@ -78,6 +43,7 @@ pub struct FunSignature { // We use a lot of `Rc`/`RefCell`s here as we want to simplify our code. // We may not really need so much `Rc`s, but we would have to do complicated // stuffs otherwise. +#[derive(Clone)] pub enum TypeEnum { TVar { // TODO: upper/lower bound @@ -138,14 +104,8 @@ impl TypeEnum { } } -impl Debug for TypeCell { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(&self.borrow().get_type_name()) - } -} - pub struct Unifier { - unification_table: InPlaceUnificationTable, + unification_table: UnificationTable>>, var_id: u32, } @@ -153,7 +113,7 @@ impl Unifier { /// Get an empty unifier pub fn new() -> Unifier { Unifier { - unification_table: InPlaceUnificationTable::new(), + unification_table: UnificationTable::new(), var_id: 0, } } @@ -161,12 +121,12 @@ impl Unifier { /// Register a type to the unifier. /// Returns a key in the unification_table. pub fn add_ty(&mut self, a: TypeEnum) -> Type { - self.unification_table.new_key(TypeCell(Rc::new(a.into()))) + self.unification_table.new_key(Rc::new(a.into())) } /// Get the TypeEnum of a type. pub fn get_ty(&mut self, a: Type) -> Rc> { - self.unification_table.probe_value(a).0 + self.unification_table.probe_value(a).clone() } /// Unify two types, i.e. a = b. @@ -187,7 +147,7 @@ impl Unifier { F: FnMut(usize) -> String, G: FnMut(u32) -> String, { - let ty = self.unification_table.probe_value(ty).0; + let ty = self.unification_table.probe_value(ty).clone(); let ty = ty.as_ref().borrow(); match &*ty { TypeEnum::TVar { id } => var_to_name(*id), @@ -252,8 +212,8 @@ impl Unifier { return Ok(()); } ( - self.unification_table.probe_value(a), - self.unification_table.probe_value(b), + self.unification_table.probe_value(a).clone(), + self.unification_table.probe_value(b).clone(), ) }; @@ -484,9 +444,9 @@ impl Unifier { fn set_a_to_b(&mut self, a: Type, b: Type) { // unify a and b together, and set the value to b's value. let table = &mut self.unification_table; - let ty_b = table.probe_value(b); - table.union(a, b); - table.union_value(a, ty_b); + let ty_b = table.probe_value(b).clone(); + table.unify(a, b); + table.set_value(a, ty_b) } fn incompatible_types(&self, a: &TypeEnum, b: &TypeEnum) -> Result<(), String> { @@ -501,7 +461,7 @@ impl Unifier { if self.unification_table.unioned(a, b) { return Err("Recursive type is prohibited.".to_owned()); } - let ty = self.unification_table.probe_value(b); + let ty = self.unification_table.probe_value(b).clone(); let ty = ty.borrow(); match &*ty { @@ -568,7 +528,7 @@ impl Unifier { /// If this returns None, the result type would be the original type /// (no substitution has to be done). fn subst(&mut self, a: Type, mapping: &VarMap) -> Option { - let ty_cell = self.unification_table.probe_value(a); + let ty_cell = self.unification_table.probe_value(a).clone(); let ty = ty_cell.borrow(); // this function would only be called when we instantiate functions. // function type signature should ONLY contain concrete types and type @@ -725,7 +685,7 @@ impl Unifier { if table.unioned(a, b) { return true; } - (table.probe_value(a), table.probe_value(b)) + (table.probe_value(a).clone(), table.probe_value(b).clone()) }; let ty_a = ty_a.borrow(); diff --git a/nac3core/src/typecheck/unification_table.rs b/nac3core/src/typecheck/unification_table.rs new file mode 100644 index 00000000..60ec8086 --- /dev/null +++ b/nac3core/src/typecheck/unification_table.rs @@ -0,0 +1,104 @@ +use std::cell::RefCell; +use std::rc::Rc; + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct UnificationKey(usize); + +pub struct UnificationTable { + parents: Vec, + ranks: Vec, + values: Vec, +} + +impl UnificationTable { + pub fn new() -> UnificationTable { + UnificationTable { + parents: Vec::new(), + ranks: Vec::new(), + values: Vec::new(), + } + } + + pub fn new_key(&mut self, v: V) -> UnificationKey { + let index = self.parents.len(); + self.parents.push(index); + self.ranks.push(0); + self.values.push(v); + UnificationKey(index) + } + + pub fn unify(&mut self, a: UnificationKey, b: UnificationKey) { + let mut a = self.find(a); + let mut b = self.find(b); + if a == b { + return; + } + if self.ranks[a] < self.ranks[b] { + std::mem::swap(&mut a, &mut b); + } + self.parents[b] = a; + if self.ranks[a] == self.ranks[b] { + self.ranks[a] += 1; + } + } + + pub fn probe_value(&mut self, a: UnificationKey) -> &V { + let index = self.find(a); + &self.values[index] + } + + pub fn set_value(&mut self, a: UnificationKey, v: V) { + let index = self.find(a); + self.values[index] = v; + } + + pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool { + self.find(a) == self.find(b) + } + + fn find(&mut self, key: UnificationKey) -> usize { + let mut root = key.0; + let mut parent = self.parents[root]; + while root != parent { + // a = parent.parent + let a = self.parents[parent]; + // root.parent = parent.parent + self.parents[root] = a; + root = parent; + // parent = root.parent + parent = a; + } + parent + } +} + +impl UnificationTable>> +where + V: Clone, +{ + pub fn into_send(self) -> UnificationTable { + let values = self + .values + .iter() + .map(|v| v.as_ref().borrow().clone()) + .collect(); + UnificationTable { + parents: self.parents, + ranks: self.ranks, + values, + } + } + + pub fn from_send(table: UnificationTable) -> UnificationTable>> { + let values = table + .values + .into_iter() + .map(|v| Rc::new(RefCell::new(v))) + .collect(); + UnificationTable { + parents: table.parents, + ranks: table.ranks, + values, + } + } +}