diff --git a/nac3core/src/typecheck/test_typedef.rs b/nac3core/src/typecheck/test_typedef.rs index 3bce23a..d85fda7 100644 --- a/nac3core/src/typecheck/test_typedef.rs +++ b/nac3core/src/typecheck/test_typedef.rs @@ -206,7 +206,7 @@ mod test { ("v1", "Tuple[int]"), ("v2", "List[int]"), ], - (("v1", "v2"), "Cannot unify TTuple with TList") + (("v1", "v2"), "Cannot unify TList with TTuple") ; "type mismatch" )] #[test_case(2, @@ -222,7 +222,7 @@ mod test { ("v1", "Tuple[int,int]"), ("v2", "Tuple[int]"), ], - (("v1", "v2"), "Cannot unify tuples with length 1 and 2") + (("v1", "v2"), "Cannot unify tuples with length 2 and 1") ; "tuple length mismatch" )] #[test_case(3, diff --git a/nac3core/src/typecheck/typedef.rs b/nac3core/src/typecheck/typedef.rs index d2da32e..4c7eaa8 100644 --- a/nac3core/src/typecheck/typedef.rs +++ b/nac3core/src/typecheck/typedef.rs @@ -3,7 +3,6 @@ use std::cell::RefCell; use std::collections::HashMap; use std::fmt::Debug; use std::iter::once; -use std::mem::swap; use std::ops::Deref; use std::rc::Rc; @@ -69,7 +68,7 @@ pub struct FuncArg { pub struct FunSignature { pub args: Vec, pub ret: Type, - pub params: VarMap, + pub vars: VarMap, } // We use a lot of `Rc`/`RefCell`s here as we want to simplify our code. @@ -117,62 +116,7 @@ pub enum TypeEnum { // `--> TCall // `--> TFunc -// We encode the types as natural numbers, and subtyping relation as divisibility. -// If a | b, b <: a. -// We assign unique prime numbers (1 to TVar, everything is a subtype of it) to each type: -// TVar = 1 -// |--> TSeq = 2 -// | |--> TTuple = 3 -// | `--> TList = 5 -// |--> TRecord = 7 -// | |--> TObj = 11 -// | `--> TVirtual = 13 -// `--> TCall = 17 -// `--> TFunc = 21 -// -// And then, based on the subtyping relation, multiply them together... -// TVar = 1 -// |--> TSeq = 2 * TVar -// | |--> TTuple = 3 * TSeq * TVar -// | `--> TList = 5 * TSeq * TVar -// |--> TRecord = 7 * TVar -// | |--> TObj = 11 * TRecord * TVar -// | `--> TVirtual = 13 * TRecord * TVar -// `--> TCall = 17 * TVar -// `--> TFunc = 21 * TCall * TVar - impl TypeEnum { - fn get_int(&self) -> i32 { - const TVAR: i32 = 1; - const TSEQ: i32 = 2; - const TTUPLE: i32 = 3; - const TLIST: i32 = 5; - const TRECORD: i32 = 7; - const TOBJ: i32 = 11; - const TVIRTUAL: i32 = 13; - const TCALL: i32 = 17; - const TFUNC: i32 = 21; - - match self { - TypeEnum::TVar { .. } => TVAR, - TypeEnum::TSeq { .. } => TSEQ * TVAR, - TypeEnum::TTuple { .. } => TTUPLE * TSEQ * TVAR, - TypeEnum::TList { .. } => TLIST * TSEQ * TVAR, - TypeEnum::TRecord { .. } => TRECORD * TVAR, - TypeEnum::TObj { .. } => TOBJ * TRECORD * TVAR, - TypeEnum::TVirtual { .. } => TVIRTUAL * TRECORD * TVAR, - TypeEnum::TCall { .. } => TCALL * TVAR, - TypeEnum::TFunc { .. } => TFUNC * TCALL * TVAR, - } - } - - // e.g. List <: Var - pub fn type_le(&self, other: &TypeEnum) -> bool { - let a = self.get_int(); - let b = other.get_int(); - (a % b) == 0 - } - pub fn get_type_name(&self) -> &'static str { // this function is for debugging only... // a proper to_str implementation requires the context @@ -227,9 +171,14 @@ impl Unifier { self.unification_table.probe_value(a).0 } + pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> { + self.unify_impl(a, b, false) + } + /// Unify two types, i.e. a = b. - pub fn unify(&mut self, mut a: Type, mut b: Type) -> Result<(), String> { - let (mut ty_a_cell, mut ty_b_cell) = { + fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> { + use TypeEnum::*; + let (ty_a_cell, ty_b_cell) = { if self.unification_table.unioned(a, b) { return Ok(()); } @@ -240,251 +189,215 @@ impl Unifier { }; let (ty_a, ty_b) = { - // simplify our pattern matching... - if ty_a_cell.borrow().type_le(&ty_b_cell.borrow()) { - swap(&mut a, &mut b); - swap(&mut ty_a_cell, &mut ty_b_cell); - } (ty_a_cell.borrow(), ty_b_cell.borrow()) }; self.occur_check(a, b)?; - match &*ty_a { - TypeEnum::TVar { .. } => { - // TODO: type variables bound check... + match (&*ty_a, &*ty_b) { + (TypeEnum::TVar { .. }, _) => { self.set_a_to_b(a, b); } - TypeEnum::TSeq { map: map1 } => { - match &*ty_b { - TypeEnum::TSeq { .. } => { - drop(ty_b); - if let TypeEnum::TSeq { map: map2 } = &mut *ty_b_cell.as_ref().borrow_mut() - { - // unify them to map2 - for (key, value) in map1.iter() { - if let Some(ty) = map2.get(key) { - self.unify(*ty, *value)?; - } else { - map2.insert(*key, *value); - } - } + (TSeq { map: map1 }, TSeq { .. }) => { + drop(ty_b); + if let TypeEnum::TSeq { map: map2 } = &mut *ty_b_cell.as_ref().borrow_mut() { + // unify them to map2 + for (key, value) in map1.iter() { + if let Some(ty) = map2.get(key) { + self.unify(*ty, *value)?; } else { - unreachable!() + map2.insert(*key, *value); } - self.set_a_to_b(a, b); - } - TypeEnum::TTuple { ty: types } => { - let len = types.len() as i32; - for (k, v) in map1.iter() { - // handle negative index - let ind = if *k < 0 { len + *k } else { *k }; - if ind >= len || ind < 0 { - return Err(format!( - "Tuple index out of range. (Length: {}, Index: {})", - types.len(), - k - )); - } - self.unify(*v, types[ind as usize])?; - } - self.set_a_to_b(a, b); - } - TypeEnum::TList { ty } => { - for v in map1.values() { - self.unify(*v, *ty)?; - } - self.set_a_to_b(a, b); - } - _ => { - return self.incompatible_types(&*ty_a, &*ty_b); } + } else { + unreachable!() } + self.set_a_to_b(a, b); } - TypeEnum::TTuple { ty: ty1 } => { - if let TypeEnum::TTuple { ty: ty2 } = &*ty_b { - if ty1.len() != ty2.len() { + (TSeq { map: map1 }, TTuple { ty: types }) => { + let len = types.len() as i32; + for (k, v) in map1.iter() { + // handle negative index + let ind = if *k < 0 { len + *k } else { *k }; + if ind >= len || ind < 0 { return Err(format!( - "Cannot unify tuples with length {} and {}", - ty1.len(), - ty2.len() + "Tuple index out of range. (Length: {}, Index: {})", + types.len(), + k )); } - for (x, y) in ty1.iter().zip(ty2.iter()) { - self.unify(*x, *y)?; - } - self.set_a_to_b(a, b); - } else { - return self.incompatible_types(&*ty_a, &*ty_b); + self.unify(*v, types[ind as usize])?; } + self.set_a_to_b(a, b); } - TypeEnum::TList { ty: ty1 } => { - if let TypeEnum::TList { ty: ty2 } = *ty_b { - self.unify(*ty1, ty2)?; - self.set_a_to_b(a, b); - } else { - return self.incompatible_types(&*ty_a, &*ty_b); + (TSeq { map: map1 }, TList { ty }) => { + for v in map1.values() { + self.unify(*v, *ty)?; } + self.set_a_to_b(a, b); } - TypeEnum::TRecord { fields: fields1 } => { - match &*ty_b { - TypeEnum::TRecord { .. } => { - drop(ty_b); - if let TypeEnum::TRecord { fields: fields2 } = - &mut *ty_b_cell.as_ref().borrow_mut() - { - for (key, value) in fields1.iter() { - if let Some(ty) = fields2.get(key) { - self.unify(*ty, *value)?; - } else { - fields2.insert(key.clone(), *value); - } - } + (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { + if ty1.len() != ty2.len() { + return Err(format!( + "Cannot unify tuples with length {} and {}", + ty1.len(), + ty2.len() + )); + } + for (x, y) in ty1.iter().zip(ty2.iter()) { + self.unify(*x, *y)?; + } + self.set_a_to_b(a, b); + } + (TList { ty: ty1 }, TList { ty: ty2 }) => { + self.unify(*ty1, *ty2)?; + self.set_a_to_b(a, b); + } + (TRecord { fields: fields1 }, TRecord { .. }) => { + drop(ty_b); + if let TypeEnum::TRecord { fields: fields2 } = &mut *ty_b_cell.as_ref().borrow_mut() + { + for (key, value) in fields1.iter() { + if let Some(ty) = fields2.get(key) { + self.unify(*ty, *value)?; } else { - unreachable!() + fields2.insert(key.clone(), *value); } - self.set_a_to_b(a, b); } - TypeEnum::TObj { - fields: fields2, .. - } => { - for (key, value) in fields1.iter() { - if let Some(ty) = fields2.get(key) { - self.unify(*ty, *value)?; - } else { - return Err(format!("No such attribute {}", key)); - } - } - self.set_a_to_b(a, b); - } - TypeEnum::TVirtual { ty } => { - // not sure if this is correct... - self.unify(a, *ty)?; - } - _ => { - return self.incompatible_types(&*ty_a, &*ty_b); + } else { + unreachable!() + } + self.set_a_to_b(a, b); + } + ( + TRecord { fields: fields1 }, + TObj { + fields: fields2, .. + }, + ) => { + for (key, value) in fields1.iter() { + if let Some(ty) = fields2.get(key) { + self.unify(*ty, *value)?; + } else { + return Err(format!("No such attribute {}", key)); } } + self.set_a_to_b(a, b); } - TypeEnum::TObj { - obj_id: id1, - params: params1, - .. - } => { - if let TypeEnum::TObj { + (TRecord { .. }, TVirtual { ty }) => { + self.unify(a, *ty)?; + } + ( + TObj { + obj_id: id1, + params: params1, + .. + }, + TObj { obj_id: id2, params: params2, .. - } = &*ty_b - { - if id1 != id2 { - return Err(format!("Cannot unify objects with ID {} and {}", id1, id2)); - } - for (x, y) in params1.values().zip(params2.values()) { - self.unify(*x, *y)?; - } - self.set_a_to_b(a, b); - } else { - return self.incompatible_types(&*ty_a, &*ty_b); + }, + ) => { + if id1 != id2 { + return Err(format!("Cannot unify objects with ID {} and {}", id1, id2)); } - } - TypeEnum::TVirtual { ty: ty1 } => { - if let TypeEnum::TVirtual { ty: ty2 } = &*ty_b { - self.unify(*ty1, *ty2)?; - self.set_a_to_b(a, b); - } else { - return self.incompatible_types(&*ty_a, &*ty_b); + for (x, y) in params1.values().zip(params2.values()) { + self.unify(*x, *y)?; } + self.set_a_to_b(a, b); } - TypeEnum::TCall { calls: c1 } => match &*ty_b { - TypeEnum::TCall { .. } => { - drop(ty_b); - if let TypeEnum::TCall { calls: c2 } = &mut *ty_b_cell.as_ref().borrow_mut() { - c2.extend(c1.iter().cloned()); + (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { + self.unify(*ty1, *ty2)?; + self.set_a_to_b(a, b); + } + (TCall { calls: c1 }, TCall { .. }) => { + drop(ty_b); + if let TypeEnum::TCall { calls: c2 } = &mut *ty_b_cell.as_ref().borrow_mut() { + c2.extend(c1.iter().cloned()); + } else { + unreachable!() + } + self.set_a_to_b(a, b); + } + (TCall { calls }, TFunc(signature)) => { + let required: Vec = signature + .args + .iter() + .filter(|v| !v.is_optional) + .map(|v| v.name.clone()) + .rev() + .collect(); + for c in calls { + let Call { + posargs, + kwargs, + ret, + fun, + } = c.as_ref(); + let instantiated = self.instantiate_fun(b, signature); + let signature; + let r = self.get_ty(instantiated); + let r = r.as_ref().borrow(); + if let TypeEnum::TFunc(s) = &*r { + signature = s; } else { - unreachable!() + unreachable!(); } - self.set_a_to_b(a, b); - } - TypeEnum::TFunc(signature) => { - let required: Vec = signature + let mut required = required.clone(); + let mut all_names: Vec<_> = signature .args .iter() - .filter(|v| !v.is_optional) - .map(|v| v.name.clone()) + .map(|v| (v.name.clone(), v.ty)) .rev() .collect(); - for c in c1 { - let Call { - posargs, - kwargs, - ret, - fun, - } = c.as_ref(); - let instantiated = self.instantiate_fun(b, signature); - let signature; - let r = self.get_ty(instantiated); - let r = r.as_ref().borrow(); - if let TypeEnum::TFunc(s) = &*r { - signature = s; + for (i, t) in posargs.iter().enumerate() { + if signature.args.len() <= i { + return Err("Too many arguments.".to_string()); + } + if !required.is_empty() { + required.pop(); + } + self.unify(all_names.pop().unwrap().1, *t)?; + } + for (k, t) in kwargs.iter() { + if let Some(i) = required.iter().position(|v| v == k) { + required.remove(i); + } + if let Some(i) = all_names.iter().position(|v| &v.0 == k) { + self.unify(all_names.remove(i).1, *t)?; } else { - unreachable!(); + return Err(format!("Unknown keyword argument {}", k)); } - let mut required = required.clone(); - let mut all_names: Vec<_> = signature - .args - .iter() - .map(|v| (v.name.clone(), v.ty)) - .rev() - .collect(); - for (i, t) in posargs.iter().enumerate() { - if signature.args.len() <= i { - return Err("Too many arguments.".to_string()); - } - if !required.is_empty() { - required.pop(); - } - self.unify(all_names.pop().unwrap().1, *t)?; - } - for (k, t) in kwargs.iter() { - if let Some(i) = required.iter().position(|v| v == k) { - required.remove(i); - } - if let Some(i) = all_names.iter().position(|v| &v.0 == k) { - self.unify(all_names.remove(i).1, *t)?; - } else { - return Err(format!("Unknown keyword argument {}", k)); - } - } - self.unify(*ret, signature.ret)?; - *fun.borrow_mut() = Some(instantiated); } - self.set_a_to_b(a, b); + self.unify(*ret, signature.ret)?; + *fun.borrow_mut() = Some(instantiated); } - _ => { + self.set_a_to_b(a, b); + } + (TFunc(sign1), TFunc(sign2)) => { + if !sign1.vars.is_empty() || !sign2.vars.is_empty() { + return Err("Polymorphic function pointer is prohibited.".to_string()); + } + if sign1.args.len() != sign2.args.len() { + return Err("Functions differ in number of parameters.".to_string()); + } + for (x, y) in sign1.args.iter().zip(sign2.args.iter()) { + if x.name != y.name { + return Err("Functions differ in parameter names.".to_string()); + } + if x.is_optional != y.is_optional { + return Err("Functions differ in optional parameters.".to_string()); + } + self.unify(x.ty, y.ty)?; + } + self.unify(sign1.ret, sign2.ret)?; + self.set_a_to_b(a, b); + } + _ => { + if swapped { return self.incompatible_types(&*ty_a, &*ty_b); - } - }, - TypeEnum::TFunc(sign1) => { - if let TypeEnum::TFunc(sign2) = &*ty_b { - if !sign1.params.is_empty() || !sign2.params.is_empty() { - return Err("Polymorphic function pointer is prohibited.".to_string()); - } - if sign1.args.len() != sign2.args.len() { - return Err("Functions differ in number of parameters.".to_string()); - } - for (x, y) in sign1.args.iter().zip(sign2.args.iter()) { - if x.name != y.name { - return Err("Functions differ in parameter names.".to_string()); - } - if x.is_optional != y.is_optional { - return Err("Functions differ in optional parameters.".to_string()); - } - self.unify(x.ty, y.ty)?; - } - self.unify(sign1.ret, sign2.ret)?; - self.set_a_to_b(a, b); } else { - return self.incompatible_types(&*ty_a, &*ty_b); + self.unify_impl(b, a, true)?; } } } @@ -555,7 +468,11 @@ impl Unifier { self.occur_check(a, *t)?; } } - TypeEnum::TFunc(FunSignature { args, ret, params }) => { + TypeEnum::TFunc(FunSignature { + args, + ret, + vars: params, + }) => { for t in args .iter() .map(|v| &v.ty) @@ -638,7 +555,11 @@ impl Unifier { None } } - TypeEnum::TFunc(FunSignature { args, ret, params }) => { + TypeEnum::TFunc(FunSignature { + args, + ret, + vars: params, + }) => { let new_params = self.subst_map(params, mapping); let new_ret = self.subst(*ret, mapping); let mut new_args = None; @@ -658,7 +579,11 @@ impl Unifier { let params = new_params.unwrap_or_else(|| params.clone()); let ret = new_ret.unwrap_or_else(|| *ret); let args = new_args.unwrap_or_else(|| args.clone()); - Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, params }))) + Some(self.add_ty(TypeEnum::TFunc(FunSignature { + args, + ret, + vars: params, + }))) } else { None } @@ -688,7 +613,7 @@ impl Unifier { /// Returns None if the function is already instantiated. fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { let mut instantiated = false; - for (k, v) in fun.params.iter() { + for (k, v) in fun.vars.iter() { if let TypeEnum::TVar { id } = &*self.unification_table.probe_value(*v).as_ref().borrow() { @@ -705,7 +630,7 @@ impl Unifier { ty } else { let mapping = fun - .params + .vars .iter() .map(|(k, _)| (*k, self.get_fresh_var().0)) .collect();