diff --git a/nac3core/src/typecheck/test_typedef.rs b/nac3core/src/typecheck/test_typedef.rs index b61bdb99..3bce23a0 100644 --- a/nac3core/src/typecheck/test_typedef.rs +++ b/nac3core/src/typecheck/test_typedef.rs @@ -8,14 +8,12 @@ mod test { struct TestEnvironment { pub unifier: Unifier, type_mapping: HashMap, - var_max_id: u32, } impl TestEnvironment { fn new() -> TestEnvironment { - let unifier = Unifier::new(); + let mut unifier = Unifier::new(); let mut type_mapping = HashMap::new(); - let mut var_max_id = 0; type_mapping.insert( "int".into(), @@ -41,38 +39,30 @@ mod test { params: HashMap::new(), }), ); - let v0 = unifier.add_ty(TypeEnum::TVar { id: 0 }); - var_max_id += 1; + let (v0, id) = unifier.get_fresh_var(); type_mapping.insert( "Foo".into(), unifier.add_ty(TypeEnum::TObj { obj_id: 3, fields: [("a".into(), v0)].iter().cloned().collect(), - params: [(0u32, v0)].iter().cloned().collect(), + params: [(id, v0)].iter().cloned().collect(), }), ); TestEnvironment { unifier, type_mapping, - var_max_id, } } - fn get_fresh_var(&mut self) -> Type { - let id = self.var_max_id + 1; - self.var_max_id += 1; - self.unifier.add_ty(TypeEnum::TVar { id }) - } - - fn parse(&self, typ: &str, mapping: &Mapping) -> Type { + fn parse(&mut self, typ: &str, mapping: &Mapping) -> Type { let result = self.internal_parse(typ, mapping); assert!(result.1.is_empty()); result.0 } fn internal_parse<'a, 'b>( - &'a self, + &'a mut self, typ: &'b str, mapping: &Mapping, ) -> (Type, &'b str) { @@ -189,8 +179,8 @@ mod test { let mut env = TestEnvironment::new(); let mut mapping = HashMap::new(); for i in 1..=variable_count { - let v = env.get_fresh_var(); - mapping.insert(format!("v{}", i), v); + let v = env.unifier.get_fresh_var(); + mapping.insert(format!("v{}", i), v.0); } // unification may have side effect when we do type resolution, so freeze the types // before doing unification. @@ -259,8 +249,8 @@ mod test { let mut env = TestEnvironment::new(); let mut mapping = HashMap::new(); for i in 1..=variable_count { - let v = env.get_fresh_var(); - mapping.insert(format!("v{}", i), v); + let v = env.unifier.get_fresh_var(); + mapping.insert(format!("v{}", i), v.0); } // unification may have side effect when we do type resolution, so freeze the types // before doing unification. diff --git a/nac3core/src/typecheck/typedef.rs b/nac3core/src/typecheck/typedef.rs index b7fd33a6..6ff93879 100644 --- a/nac3core/src/typecheck/typedef.rs +++ b/nac3core/src/typecheck/typedef.rs @@ -48,7 +48,7 @@ impl Deref for TypeCell { } pub type Mapping = HashMap; -pub type VarMap = Mapping; +type VarMap = Mapping; #[derive(Clone)] pub struct Call { @@ -65,6 +65,13 @@ pub struct FuncArg { is_optional: bool, } +#[derive(Clone)] +pub struct FunSignature { + args: Vec, + ret: Type, + params: VarMap, +} + // We use a lot of `Rc`/`RefCell`s here as we want to simplify our code. // We may not really need so much `Rc`s, but we would have to do complicated // stuffs otherwise. @@ -96,11 +103,7 @@ pub enum TypeEnum { TCall { calls: Vec>, }, - TFunc { - args: Vec, - ret: Type, - params: VarMap, - }, + TFunc(FunSignature), } // Order: @@ -199,40 +202,41 @@ pub struct ObjDef { } pub struct Unifier { - unification_table: RefCell>, + unification_table: InPlaceUnificationTable, obj_def_table: Vec, + var_id: u32, } impl Unifier { pub fn new() -> Unifier { Unifier { - unification_table: RefCell::new(InPlaceUnificationTable::new()), + unification_table: InPlaceUnificationTable::new(), obj_def_table: Vec::new(), + var_id: 0, } } /// Register a type to the unifier. /// Returns a key in the unification_table. - pub fn add_ty(&self, a: TypeEnum) -> Type { - self.unification_table - .borrow_mut() - .new_key(TypeCell(Rc::new(a.into()))) + pub fn add_ty(&mut self, a: TypeEnum) -> Type { + self.unification_table.new_key(TypeCell(Rc::new(a.into()))) } /// Get the TypeEnum of a type. - pub fn get_ty(&self, a: Type) -> Rc> { - let mut table = self.unification_table.borrow_mut(); - table.probe_value(a).0 + pub fn get_ty(&mut self, a: Type) -> Rc> { + self.unification_table.probe_value(a).0 } /// Unify two types, i.e. a = b. - pub fn unify(&self, mut a: Type, mut b: Type) -> Result<(), String> { + pub fn unify(&mut self, mut a: Type, mut b: Type) -> Result<(), String> { let (mut ty_a_cell, mut ty_b_cell) = { - let mut table = self.unification_table.borrow_mut(); - if table.unioned(a, b) { + if self.unification_table.unioned(a, b) { return Ok(()); } - (table.probe_value(a), table.probe_value(b)) + ( + self.unification_table.probe_value(a), + self.unification_table.probe_value(b), + ) }; let (ty_a, ty_b) = { @@ -353,7 +357,6 @@ impl Unifier { TypeEnum::TVirtual { ty } => { // not sure if this is correct... self.unify(a, *ty)?; - self.set_a_to_b(a, b); } _ => { return self.incompatible_types(&*ty_a, &*ty_b); @@ -390,14 +393,105 @@ impl Unifier { return self.incompatible_types(&*ty_a, &*ty_b); } } - _ => unimplemented!(), + TypeEnum::TCall { calls: c1 } => match &*ty_b { + TypeEnum::TCall { .. } => { + drop(ty_b); + if let TypeEnum::TCall { calls: c2 } = &mut *ty_b_cell.as_ref().borrow_mut() { + c2.extend(c1.iter().cloned()); + } else { + unreachable!() + } + self.set_a_to_b(a, b); + } + TypeEnum::TFunc(signature) => { + let required: Vec = signature + .args + .iter() + .filter(|v| !v.is_optional) + .map(|v| v.name.clone()) + .rev() + .collect(); + for c in c1 { + let Call { + posargs, + kwargs, + ret, + fun, + } = c.as_ref(); + let instantiated = self.instantiate_fun(b, signature); + let signature; + let r = self.get_ty(instantiated); + let r = r.as_ref().borrow(); + if let TypeEnum::TFunc(s) = &*r { + signature = s; + } else { + unreachable!(); + } + let mut required = required.clone(); + let mut all_names: Vec<_> = signature + .args + .iter() + .map(|v| (v.name.clone(), v.ty)) + .rev() + .collect(); + for (i, t) in posargs.iter().enumerate() { + if signature.args.len() <= i { + return Err(format!("Too many arguments.")); + } + if !required.is_empty() { + required.pop(); + } + self.unify(all_names.pop().unwrap().1, *t)?; + } + for (k, t) in kwargs.iter() { + if let Some(i) = required.iter().position(|v| v == k) { + required.remove(i); + } + if let Some(i) = all_names.iter().position(|v| &v.0 == k) { + self.unify(all_names.remove(i).1, *t)?; + } else { + return Err(format!("Unknown keyword argument {}", k)); + } + } + self.unify(*ret, signature.ret)?; + *fun.borrow_mut() = Some(instantiated); + } + self.set_a_to_b(a, b); + } + _ => { + return self.incompatible_types(&*ty_a, &*ty_b); + } + }, + TypeEnum::TFunc(sign1) => { + if let TypeEnum::TFunc(sign2) = &*ty_b { + if !sign1.params.is_empty() || !sign2.params.is_empty() { + return Err(format!("Polymorphic function pointer is prohibited.")); + } + if sign1.args.len() != sign2.args.len() { + return Err(format!("Functions differ in number of parameters.")); + } + for (x, y) in sign1.args.iter().zip(sign2.args.iter()) { + if x.name != y.name { + return Err(format!("Functions differ in parameter names.")); + } + if x.is_optional != y.is_optional { + return Err(format!("Functions differ in optional parameters.")); + } + self.unify(x.ty, y.ty)?; + } + self.unify(sign1.ret, sign2.ret)?; + self.set_a_to_b(a, b); + } else { + return self.incompatible_types(&*ty_a, &*ty_b); + } + } } Ok(()) } - fn set_a_to_b(&self, a: Type, b: Type) { + fn set_a_to_b(&mut self, a: Type, b: Type) { // unify a and b together, and set the value to b's value. - let mut table = self.unification_table.borrow_mut(); + let table = &mut self.unification_table; let ty_b = table.probe_value(b); table.union(a, b); table.union_value(a, ty_b); @@ -411,11 +505,11 @@ impl Unifier { )) } - fn occur_check(&self, a: Type, b: Type) -> Result<(), String> { - if self.unification_table.borrow_mut().unioned(a, b) { + fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> { + if self.unification_table.unioned(a, b) { return Err("Recursive type is prohibited.".to_owned()); } - let ty = self.unification_table.borrow_mut().probe_value(b); + let ty = self.unification_table.probe_value(b); let ty = ty.borrow(); match &*ty { @@ -454,7 +548,7 @@ impl Unifier { self.occur_check(a, *t)?; } } - TypeEnum::TFunc { args, ret, params } => { + TypeEnum::TFunc(FunSignature { args, ret, params }) => { for t in args .iter() .map(|v| &v.ty) @@ -472,8 +566,8 @@ impl Unifier { /// If this returns Some(T), T would be the substituted type. /// If this returns None, the result type would be the original type /// (no substitution has to be done). - pub fn subst(&self, a: Type, mapping: &VarMap) -> Option { - let ty_cell = self.unification_table.borrow_mut().probe_value(a); + pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option { + let ty_cell = self.unification_table.probe_value(a); let ty = ty_cell.borrow(); // this function would only be called when we instantiate functions. // function type signature should ONLY contain concrete types and type @@ -512,10 +606,10 @@ impl Unifier { // parameter list, we don't need to substitute the fields. // This is also used to prevent infinite substitution... let need_subst = params.values().any(|v| { - let ty_cell = self.unification_table.borrow_mut().probe_value(*v); + let ty_cell = self.unification_table.probe_value(*v); let ty = ty_cell.borrow(); if let TypeEnum::TVar { id } = &*ty { - mapping.contains_key(id) + mapping.contains_key(&id) } else { false } @@ -537,7 +631,7 @@ impl Unifier { None } } - TypeEnum::TFunc { args, ret, params } => { + TypeEnum::TFunc(FunSignature { args, ret, params }) => { let new_params = self.subst_map(params, mapping); let new_ret = self.subst(*ret, mapping); let mut new_args = None; @@ -557,7 +651,7 @@ impl Unifier { let params = new_params.unwrap_or_else(|| params.clone()); let ret = new_ret.unwrap_or_else(|| *ret); let args = new_args.unwrap_or_else(|| args.clone()); - Some(self.add_ty(TypeEnum::TFunc { params, ret, args })) + Some(self.add_ty(TypeEnum::TFunc(FunSignature { params, ret, args }))) } else { None } @@ -566,7 +660,7 @@ impl Unifier { } } - fn subst_map(&self, map: &Mapping, mapping: &VarMap) -> Option> + fn subst_map(&mut self, map: &Mapping, mapping: &VarMap) -> Option> where K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, { @@ -582,13 +676,43 @@ impl Unifier { map2 } + /// Instantiate a function if it hasn't been instntiated. + /// Returns Some(T) where T is the instantiated type. + /// Returns None if the function is already instantiated. + fn instantiate_fun(&mut self, ty: Type, fun: &FunSignature) -> Type { + let mut instantiated = false; + for (k, v) in fun.params.iter() { + if let TypeEnum::TVar { id } = + &*self.unification_table.probe_value(*v).as_ref().borrow() + { + if k != id { + instantiated = true; + break; + } + } else { + instantiated = true; + break; + } + } + if instantiated { + ty + } else { + let mapping = fun + .params + .iter() + .map(|(k, _)| (*k, self.get_fresh_var().0)) + .collect(); + self.subst(ty, &mapping).unwrap_or(ty) + } + } + /// Check whether two types are equal. - pub fn eq(&self, a: Type, b: Type) -> bool { + pub fn eq(&mut self, a: Type, b: Type) -> bool { if a == b { return true; } let (ty_a, ty_b) = { - let mut table = self.unification_table.borrow_mut(); + let table = &mut self.unification_table; if table.unioned(a, b) { return true; } @@ -629,7 +753,7 @@ impl Unifier { } } - fn map_eq(&self, map1: &Mapping, map2: &Mapping) -> bool + fn map_eq(&mut self, map1: &Mapping, map2: &Mapping) -> bool where K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, { @@ -643,4 +767,11 @@ impl Unifier { } true } + + /// Get a fresh type variable. + pub fn get_fresh_var(&mut self) -> (Type, u32) { + let id = self.var_id + 1; + self.var_id += 1; + (self.add_ty(TypeEnum::TVar { id }), id) + } }