added tests
This commit is contained in:
parent
1df3f4e757
commit
d94f25583b
|
@ -403,9 +403,11 @@ dependencies = [
|
|||
"generational-arena",
|
||||
"indoc 1.0.3",
|
||||
"inkwell",
|
||||
"itertools",
|
||||
"num-bigint 0.3.2",
|
||||
"num-traits",
|
||||
"rustpython-parser",
|
||||
"test-case",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -844,6 +846,19 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "test-case"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b114ece25254e97bf48dd4bfc2a12bad0647adacfe4cae1247a9ca6ad302cec"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiny-keccak"
|
||||
version = "2.0.2"
|
||||
|
|
|
@ -13,3 +13,7 @@ indoc = "1.0"
|
|||
generational-arena = "0.2"
|
||||
ena = "0.14"
|
||||
|
||||
[dev-dependencies]
|
||||
test-case = "1.2.0"
|
||||
itertools = "0.10.1"
|
||||
|
||||
|
|
|
@ -4,4 +4,5 @@
|
|||
// mod magic_methods;
|
||||
// mod primitives;
|
||||
// pub mod symbol_resolver;
|
||||
mod test_typedef;
|
||||
pub mod typedef;
|
||||
|
|
|
@ -0,0 +1,273 @@
|
|||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::super::typedef::*;
|
||||
use itertools::Itertools;
|
||||
use std::collections::HashMap;
|
||||
use test_case::test_case;
|
||||
|
||||
struct TestEnvironment {
|
||||
pub unifier: Unifier,
|
||||
type_mapping: HashMap<String, Type>,
|
||||
var_max_id: u32,
|
||||
}
|
||||
|
||||
impl TestEnvironment {
|
||||
fn new() -> TestEnvironment {
|
||||
let unifier = Unifier::new();
|
||||
let mut type_mapping = HashMap::new();
|
||||
let mut var_max_id = 0;
|
||||
|
||||
type_mapping.insert(
|
||||
"int".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 0,
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
}),
|
||||
);
|
||||
type_mapping.insert(
|
||||
"float".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 1,
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
}),
|
||||
);
|
||||
type_mapping.insert(
|
||||
"bool".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 2,
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
}),
|
||||
);
|
||||
let v0 = unifier.add_ty(TypeEnum::TVar { id: 0 });
|
||||
var_max_id += 1;
|
||||
type_mapping.insert(
|
||||
"Foo".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 3,
|
||||
fields: [("a".into(), v0)].iter().cloned().collect(),
|
||||
params: [(0u32, v0)].iter().cloned().collect(),
|
||||
}),
|
||||
);
|
||||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
type_mapping,
|
||||
var_max_id,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_fresh_var(&mut self) -> Type {
|
||||
let id = self.var_max_id + 1;
|
||||
self.var_max_id += 1;
|
||||
self.unifier.add_ty(TypeEnum::TVar { id })
|
||||
}
|
||||
|
||||
fn parse(&self, typ: &str, mapping: &Mapping<String>) -> Type {
|
||||
let result = self.internal_parse(typ, mapping);
|
||||
assert!(result.1.is_empty());
|
||||
result.0
|
||||
}
|
||||
|
||||
fn internal_parse<'a, 'b>(
|
||||
&'a self,
|
||||
typ: &'b str,
|
||||
mapping: &Mapping<String>,
|
||||
) -> (Type, &'b str) {
|
||||
// for testing only, so we can just panic when the input is malformed
|
||||
let end = typ
|
||||
.find(|c| ['[', ',', ']', '='].contains(&c))
|
||||
.unwrap_or_else(|| typ.len());
|
||||
match &typ[..end] {
|
||||
"Tuple" => {
|
||||
let mut s = &typ[end..];
|
||||
assert!(&s[0..1] == "[");
|
||||
let mut ty = Vec::new();
|
||||
while &s[0..1] != "]" {
|
||||
let result = self.internal_parse(&s[1..], mapping);
|
||||
ty.push(result.0);
|
||||
s = result.1;
|
||||
}
|
||||
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
|
||||
}
|
||||
"List" => {
|
||||
assert!(&typ[end..end + 1] == "[");
|
||||
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
|
||||
assert!(&s[0..1] == "]");
|
||||
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
|
||||
}
|
||||
"Record" => {
|
||||
let mut s = &typ[end..];
|
||||
assert!(&s[0..1] == "[");
|
||||
let mut fields = HashMap::new();
|
||||
while &s[0..1] != "]" {
|
||||
let eq = s.find('=').unwrap();
|
||||
let key = s[1..eq].to_string();
|
||||
let result = self.internal_parse(&s[eq + 1..], mapping);
|
||||
fields.insert(key, result.0);
|
||||
s = result.1;
|
||||
}
|
||||
(self.unifier.add_ty(TypeEnum::TRecord { fields }), &s[1..])
|
||||
}
|
||||
x => {
|
||||
let mut s = &typ[end..];
|
||||
let ty = mapping.get(x).cloned().unwrap_or_else(|| {
|
||||
// mapping should be type variables, type_mapping should be concrete types
|
||||
// we should not resolve the type of type variables.
|
||||
let mut ty = *self.type_mapping.get(x).unwrap();
|
||||
let te = self.unifier.get_ty(ty);
|
||||
if let TypeEnum::TObj { params, .. } = &*te.as_ref().borrow() {
|
||||
if !params.is_empty() {
|
||||
assert!(&s[0..1] == "[");
|
||||
let mut p = Vec::new();
|
||||
while &s[0..1] != "]" {
|
||||
let result = self.internal_parse(&s[1..], mapping);
|
||||
p.push(result.0);
|
||||
s = result.1;
|
||||
}
|
||||
s = &s[1..];
|
||||
ty = self
|
||||
.unifier
|
||||
.subst(ty, ¶ms.keys().cloned().zip(p.into_iter()).collect())
|
||||
.unwrap_or(ty);
|
||||
}
|
||||
}
|
||||
ty
|
||||
});
|
||||
(ty, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test_case(2,
|
||||
&[("v1", "v2"), ("v2", "float")],
|
||||
&[("v1", "float"), ("v2", "float")]
|
||||
; "simple variable"
|
||||
)]
|
||||
#[test_case(2,
|
||||
&[("v1", "List[v2]"), ("v1", "List[float]")],
|
||||
&[("v1", "List[float]"), ("v2", "float")]
|
||||
; "list element"
|
||||
)]
|
||||
#[test_case(3,
|
||||
&[
|
||||
("v1", "Record[a=v3,b=v3]"),
|
||||
("v2", "Record[b=float,c=v3]"),
|
||||
("v1", "v2")
|
||||
],
|
||||
&[
|
||||
("v1", "Record[a=float,b=float,c=float]"),
|
||||
("v2", "Record[a=float,b=float,c=float]"),
|
||||
("v3", "float")
|
||||
]
|
||||
; "record merge"
|
||||
)]
|
||||
#[test_case(3,
|
||||
&[
|
||||
("v1", "Record[a=float]"),
|
||||
("v2", "Foo[v3]"),
|
||||
("v1", "v2")
|
||||
],
|
||||
&[
|
||||
("v1", "Foo[float]"),
|
||||
("v3", "float")
|
||||
]
|
||||
; "record obj merge"
|
||||
)]
|
||||
fn test_unify(
|
||||
variable_count: u32,
|
||||
unify_pairs: &[(&'static str, &'static str)],
|
||||
verify_pairs: &[(&'static str, &'static str)],
|
||||
) {
|
||||
let unify_count = unify_pairs.len();
|
||||
// test all permutations...
|
||||
for perm in unify_pairs.iter().permutations(unify_count) {
|
||||
let mut env = TestEnvironment::new();
|
||||
let mut mapping = HashMap::new();
|
||||
for i in 1..=variable_count {
|
||||
let v = env.get_fresh_var();
|
||||
mapping.insert(format!("v{}", i), v);
|
||||
}
|
||||
// unification may have side effect when we do type resolution, so freeze the types
|
||||
// before doing unification.
|
||||
let mut pairs = Vec::new();
|
||||
for (a, b) in perm.iter() {
|
||||
let t1 = env.parse(a, &mapping);
|
||||
let t2 = env.parse(b, &mapping);
|
||||
pairs.push((t1, t2));
|
||||
}
|
||||
for (t1, t2) in pairs {
|
||||
env.unifier.unify(t1, t2).unwrap();
|
||||
}
|
||||
for (a, b) in verify_pairs.iter() {
|
||||
let t1 = env.parse(a, &mapping);
|
||||
let t2 = env.parse(b, &mapping);
|
||||
assert!(env.unifier.eq(t1, t2));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test_case(2,
|
||||
&[
|
||||
("v1", "Tuple[int]"),
|
||||
("v2", "List[int]"),
|
||||
],
|
||||
(("v1", "v2"), "Cannot unify TTuple with TList")
|
||||
; "kind mismatch"
|
||||
)]
|
||||
#[test_case(2,
|
||||
&[
|
||||
("v1", "Tuple[int]"),
|
||||
("v2", "Tuple[float]"),
|
||||
],
|
||||
(("v1", "v2"), "Cannot unify objects with ID 0 and 1")
|
||||
; "tuple parameter mismatch"
|
||||
)]
|
||||
#[test_case(2,
|
||||
&[
|
||||
("v1", "Tuple[int,int]"),
|
||||
("v2", "Tuple[int]"),
|
||||
],
|
||||
(("v1", "v2"), "Cannot unify tuples with length 1 and 2")
|
||||
; "tuple length mismatch"
|
||||
)]
|
||||
#[test_case(3,
|
||||
&[
|
||||
("v1", "Record[a=float,b=int]"),
|
||||
("v2", "Foo[v3]"),
|
||||
],
|
||||
(("v1", "v2"), "No such attribute b")
|
||||
; "record obj merge"
|
||||
)]
|
||||
fn test_invalid_unification(
|
||||
variable_count: u32,
|
||||
unify_pairs: &[(&'static str, &'static str)],
|
||||
errornous_pair: ((&'static str, &'static str), &'static str),
|
||||
) {
|
||||
let mut env = TestEnvironment::new();
|
||||
let mut mapping = HashMap::new();
|
||||
for i in 1..=variable_count {
|
||||
let v = env.get_fresh_var();
|
||||
mapping.insert(format!("v{}", i), v);
|
||||
}
|
||||
// unification may have side effect when we do type resolution, so freeze the types
|
||||
// before doing unification.
|
||||
let mut pairs = Vec::new();
|
||||
for (a, b) in unify_pairs.iter() {
|
||||
let t1 = env.parse(a, &mapping);
|
||||
let t2 = env.parse(b, &mapping);
|
||||
pairs.push((t1, t2));
|
||||
}
|
||||
let (t1, t2) = (
|
||||
env.parse(errornous_pair.0 .0, &mapping),
|
||||
env.parse(errornous_pair.0 .1, &mapping),
|
||||
);
|
||||
for (a, b) in pairs {
|
||||
env.unifier.unify(a, b).unwrap();
|
||||
}
|
||||
assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string()));
|
||||
}
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue};
|
||||
use generational_arena::{Arena, Index};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::iter::once;
|
||||
use std::mem::swap;
|
||||
use std::rc::Rc;
|
||||
|
@ -18,10 +18,10 @@ use std::rc::Rc;
|
|||
// `--> TFunc
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
|
||||
struct Type(u32);
|
||||
pub struct Type(u32);
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
struct TypeIndex(Index);
|
||||
pub struct TypeIndex(Index);
|
||||
|
||||
impl UnifyValue for TypeIndex {
|
||||
type Error = NoError;
|
||||
|
@ -48,19 +48,19 @@ impl UnifyKey for Type {
|
|||
}
|
||||
}
|
||||
|
||||
type Mapping<K, V = Type> = BTreeMap<K, V>;
|
||||
type VarMap = Mapping<u32>;
|
||||
pub type Mapping<K, V = Type> = HashMap<K, V>;
|
||||
pub type VarMap = Mapping<u32>;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Call {
|
||||
pub struct Call {
|
||||
posargs: Vec<Type>,
|
||||
kwargs: BTreeMap<String, Type>,
|
||||
kwargs: HashMap<String, Type>,
|
||||
ret: Type,
|
||||
fn_id: usize,
|
||||
fun: RefCell<Option<Type>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FuncArg {
|
||||
pub struct FuncArg {
|
||||
name: String,
|
||||
ty: Type,
|
||||
is_optional: bool,
|
||||
|
@ -69,7 +69,7 @@ struct FuncArg {
|
|||
// We use a lot of `Rc`/`RefCell`s here as we want to simplify our code.
|
||||
// We may not really need so much `Rc`s, but we would have to do complicated
|
||||
// stuffs otherwise.
|
||||
enum TypeEnum {
|
||||
pub enum TypeEnum {
|
||||
TVar {
|
||||
// TODO: upper/lower bound
|
||||
id: u32,
|
||||
|
@ -95,7 +95,7 @@ enum TypeEnum {
|
|||
ty: Type,
|
||||
},
|
||||
TCall {
|
||||
calls: Vec<Call>,
|
||||
calls: Vec<Rc<Call>>,
|
||||
},
|
||||
TFunc {
|
||||
args: Vec<FuncArg>,
|
||||
|
@ -143,19 +143,40 @@ impl TypeEnum {
|
|||
}
|
||||
}
|
||||
|
||||
struct ObjDef {
|
||||
pub struct ObjDef {
|
||||
name: String,
|
||||
fields: Mapping<String>,
|
||||
}
|
||||
|
||||
struct Unifier {
|
||||
pub struct Unifier {
|
||||
unification_table: RefCell<InPlaceUnificationTable<Type>>,
|
||||
type_arena: RefCell<Arena<Rc<RefCell<TypeEnum>>>>,
|
||||
obj_def_table: Vec<ObjDef>,
|
||||
}
|
||||
|
||||
impl Unifier {
|
||||
fn unify(&self, mut a: Type, mut b: Type) -> Result<(), String> {
|
||||
pub fn new() -> Unifier {
|
||||
Unifier {
|
||||
unification_table: RefCell::new(InPlaceUnificationTable::new()),
|
||||
type_arena: RefCell::new(Arena::new()),
|
||||
obj_def_table: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_ty(&self, a: TypeEnum) -> Type {
|
||||
let index = self.type_arena.borrow_mut().insert(Rc::new(a.into()));
|
||||
self.unification_table
|
||||
.borrow_mut()
|
||||
.new_key(TypeIndex(index))
|
||||
}
|
||||
|
||||
pub fn get_ty(&self, a: Type) -> Rc<RefCell<TypeEnum>> {
|
||||
let mut table = self.unification_table.borrow_mut();
|
||||
let arena = self.type_arena.borrow();
|
||||
arena.get(table.probe_value(a).0).unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn unify(&self, mut a: Type, mut b: Type) -> Result<(), String> {
|
||||
let (mut i_a, mut i_b) = {
|
||||
let mut table = self.unification_table.borrow_mut();
|
||||
(table.probe_value(a), table.probe_value(b))
|
||||
|
@ -186,38 +207,21 @@ impl Unifier {
|
|||
self.occur_check(i_a, b)?;
|
||||
match &*ty_a {
|
||||
TypeEnum::TVar { .. } => {
|
||||
match *ty_b {
|
||||
TypeEnum::TVar { .. } => {
|
||||
// TODO: type variables bound check
|
||||
let old = {
|
||||
let mut table = self.unification_table.borrow_mut();
|
||||
table.union(a, b);
|
||||
if table.find(a) == a {
|
||||
i_b
|
||||
} else {
|
||||
i_a
|
||||
}
|
||||
};
|
||||
self.type_arena.borrow_mut().remove(old.0);
|
||||
}
|
||||
_ => {
|
||||
// TODO: type variables bound check
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
}
|
||||
// TODO: type variables bound check...
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
TypeEnum::TSeq { map: map1 } => {
|
||||
match &*ty_b {
|
||||
TypeEnum::TSeq { map: map2 } => {
|
||||
drop(ty_a);
|
||||
if let TypeEnum::TSeq { map: map1 } = &mut *ty_a_cell.as_ref().borrow_mut()
|
||||
TypeEnum::TSeq { .. } => {
|
||||
drop(ty_b);
|
||||
if let TypeEnum::TSeq { map: map2 } = &mut *ty_b_cell.as_ref().borrow_mut()
|
||||
{
|
||||
// unify them to map1
|
||||
for (key, value) in map2.iter() {
|
||||
if let Some(ty) = map1.get(key) {
|
||||
for (key, value) in map1.iter() {
|
||||
if let Some(ty) = map2.get(key) {
|
||||
self.unify(*ty, *value)?;
|
||||
} else {
|
||||
map1.insert(*key, *value);
|
||||
map2.insert(*key, *value);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -277,16 +281,16 @@ impl Unifier {
|
|||
}
|
||||
TypeEnum::TRecord { fields: fields1 } => {
|
||||
match &*ty_b {
|
||||
TypeEnum::TRecord { fields: fields2 } => {
|
||||
drop(ty_a);
|
||||
if let TypeEnum::TRecord { fields: fields1 } =
|
||||
&mut *ty_a_cell.as_ref().borrow_mut()
|
||||
TypeEnum::TRecord { .. } => {
|
||||
drop(ty_b);
|
||||
if let TypeEnum::TRecord { fields: fields2 } =
|
||||
&mut *ty_b_cell.as_ref().borrow_mut()
|
||||
{
|
||||
for (key, value) in fields2.iter() {
|
||||
if let Some(ty) = fields1.get(key) {
|
||||
for (key, value) in fields1.iter() {
|
||||
if let Some(ty) = fields2.get(key) {
|
||||
self.unify(*ty, *value)?;
|
||||
} else {
|
||||
fields1.insert(key.clone(), *value);
|
||||
fields2.insert(key.clone(), *value);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -341,6 +345,7 @@ impl Unifier {
|
|||
TypeEnum::TVirtual { ty: ty1 } => {
|
||||
if let TypeEnum::TVirtual { ty: ty2 } = &*ty_b {
|
||||
self.unify(*ty1, *ty2)?;
|
||||
self.set_a_to_b(a, b);
|
||||
} else {
|
||||
return self.report_kind_error(&*ty_a, &*ty_b);
|
||||
}
|
||||
|
@ -427,7 +432,7 @@ impl Unifier {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn subst(&self, a: Type, mapping: &VarMap) -> Option<Type> {
|
||||
pub fn subst(&self, a: Type, mapping: &VarMap) -> Option<Type> {
|
||||
let index = self.unification_table.borrow_mut().probe_value(a);
|
||||
let ty_cell = {
|
||||
let arena = self.type_arena.borrow();
|
||||
|
@ -459,34 +464,14 @@ impl Unifier {
|
|||
new_ty.as_mut().unwrap()[i] = t1;
|
||||
}
|
||||
}
|
||||
new_ty.map(|t| {
|
||||
let index = self
|
||||
.type_arena
|
||||
.borrow_mut()
|
||||
.insert(Rc::new(TypeEnum::TTuple { ty: t }.into()));
|
||||
self.unification_table
|
||||
.borrow_mut()
|
||||
.new_key(TypeIndex(index))
|
||||
})
|
||||
new_ty.map(|t| self.add_ty(TypeEnum::TTuple { ty: t }))
|
||||
}
|
||||
TypeEnum::TList { ty } => self.subst(*ty, mapping).map(|t| {
|
||||
let index = self
|
||||
.type_arena
|
||||
.borrow_mut()
|
||||
.insert(Rc::new(TypeEnum::TList { ty: t }.into()));
|
||||
self.unification_table
|
||||
.borrow_mut()
|
||||
.new_key(TypeIndex(index))
|
||||
}),
|
||||
TypeEnum::TVirtual { ty } => self.subst(*ty, mapping).map(|t| {
|
||||
let index = self
|
||||
.type_arena
|
||||
.borrow_mut()
|
||||
.insert(Rc::new(TypeEnum::TVirtual { ty: t }.into()));
|
||||
self.unification_table
|
||||
.borrow_mut()
|
||||
.new_key(TypeIndex(index))
|
||||
}),
|
||||
TypeEnum::TList { ty } => self
|
||||
.subst(*ty, mapping)
|
||||
.map(|t| self.add_ty(TypeEnum::TList { ty: t })),
|
||||
TypeEnum::TVirtual { ty } => self
|
||||
.subst(*ty, mapping)
|
||||
.map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })),
|
||||
TypeEnum::TObj {
|
||||
obj_id,
|
||||
fields,
|
||||
|
@ -508,23 +493,18 @@ impl Unifier {
|
|||
}
|
||||
});
|
||||
if need_subst {
|
||||
let index = self.type_arena.borrow_mut().insert(Rc::new(
|
||||
TypeEnum::TObj {
|
||||
obj_id: *obj_id,
|
||||
params: self
|
||||
.subst_map(¶ms, mapping)
|
||||
.unwrap_or_else(|| params.clone()),
|
||||
fields: self
|
||||
.subst_map(&fields, mapping)
|
||||
.unwrap_or_else(|| fields.clone()),
|
||||
}
|
||||
.into(),
|
||||
));
|
||||
Some(
|
||||
self.unification_table
|
||||
.borrow_mut()
|
||||
.new_key(TypeIndex(index)),
|
||||
)
|
||||
let obj_id = *obj_id;
|
||||
let params = self
|
||||
.subst_map(¶ms, mapping)
|
||||
.unwrap_or_else(|| params.clone());
|
||||
let fields = self
|
||||
.subst_map(&fields, mapping)
|
||||
.unwrap_or_else(|| fields.clone());
|
||||
Some(self.add_ty(TypeEnum::TObj {
|
||||
obj_id,
|
||||
params,
|
||||
fields,
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -546,19 +526,10 @@ impl Unifier {
|
|||
}
|
||||
}
|
||||
if new_params.is_some() || new_ret.is_some() || new_args.is_some() {
|
||||
let index = self.type_arena.borrow_mut().insert(Rc::new(
|
||||
TypeEnum::TFunc {
|
||||
params: new_params.unwrap_or_else(|| params.clone()),
|
||||
ret: new_ret.unwrap_or_else(|| *ret),
|
||||
args: new_args.unwrap_or_else(|| args.clone()),
|
||||
}
|
||||
.into(),
|
||||
));
|
||||
Some(
|
||||
self.unification_table
|
||||
.borrow_mut()
|
||||
.new_key(TypeIndex(index)),
|
||||
)
|
||||
let params = new_params.unwrap_or_else(|| params.clone());
|
||||
let ret = new_ret.unwrap_or_else(|| *ret);
|
||||
let args = new_args.unwrap_or_else(|| args.clone());
|
||||
Some(self.add_ty(TypeEnum::TFunc { params, ret, args }))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -569,7 +540,7 @@ impl Unifier {
|
|||
|
||||
fn subst_map<K>(&self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
|
||||
where
|
||||
K: std::cmp::Ord + std::clone::Clone,
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
{
|
||||
let mut map2 = None;
|
||||
for (k, v) in map.iter() {
|
||||
|
@ -582,4 +553,74 @@ impl Unifier {
|
|||
}
|
||||
map2
|
||||
}
|
||||
|
||||
pub fn eq(&self, a: Type, b: Type) -> bool {
|
||||
if a == b {
|
||||
return true;
|
||||
}
|
||||
let (i_a, i_b) = {
|
||||
let mut table = self.unification_table.borrow_mut();
|
||||
(table.probe_value(a), table.probe_value(b))
|
||||
};
|
||||
|
||||
if i_a == i_b {
|
||||
return true;
|
||||
}
|
||||
|
||||
let (ty_a, ty_b) = {
|
||||
let arena = self.type_arena.borrow();
|
||||
(
|
||||
arena.get(i_a.0).unwrap().clone(),
|
||||
arena.get(i_b.0).unwrap().clone(),
|
||||
)
|
||||
};
|
||||
|
||||
let ty_a = ty_a.borrow();
|
||||
let ty_b = ty_b.borrow();
|
||||
|
||||
match (&*ty_a, &*ty_b) {
|
||||
(TypeEnum::TVar { id: id1 }, TypeEnum::TVar { id: id2 }) => id1 == id2,
|
||||
(TypeEnum::TSeq { map: map1 }, TypeEnum::TSeq { map: map2 }) => self.map_eq(map1, map2),
|
||||
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
|
||||
ty1.len() == ty2.len()
|
||||
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
|
||||
}
|
||||
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
|
||||
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
|
||||
self.eq(*ty1, *ty2)
|
||||
}
|
||||
(TypeEnum::TRecord { fields: fields1 }, TypeEnum::TRecord { fields: fields2 }) => {
|
||||
self.map_eq(fields1, fields2)
|
||||
}
|
||||
(
|
||||
TypeEnum::TObj {
|
||||
obj_id: id1,
|
||||
params: params1,
|
||||
..
|
||||
},
|
||||
TypeEnum::TObj {
|
||||
obj_id: id2,
|
||||
params: params2,
|
||||
..
|
||||
},
|
||||
) => id1 == id2 && self.map_eq(params1, params2),
|
||||
// TCall and TFunc are not yet implemented
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn map_eq<K>(&self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
{
|
||||
if map1.len() != map2.len() {
|
||||
return false;
|
||||
}
|
||||
for (k, v) in map1.iter() {
|
||||
if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue