diff --git a/nac3type/src/inference.rs b/nac3type/src/inference.rs new file mode 100644 index 0000000000..9432938184 --- /dev/null +++ b/nac3type/src/inference.rs @@ -0,0 +1,107 @@ +use super::types::{Type::*, *}; +use std::collections::HashMap; +use std::rc::Rc; + +fn find_subst( + ctx: &GlobalContext, + assumptions: & HashMap>, + sub: &mut HashMap>, + mut a: Rc, + mut b: Rc, +) -> Result<(), String> { + // TODO: fix error messages later + if let TypeVariable(id) = a.as_ref() { + if let Some(c) = assumptions.get(&id) { + a = c.clone(); + } + } + + if let TypeVariable(id) = b.as_ref() { + if let Some(c) = sub.get(&id) { + b = c.clone(); + } + } + + match (a.as_ref(), b.as_ref()) { + (BotType, _) => Ok(()), + (TypeVariable(id_a), TypeVariable(id_b)) => { + let v_a = ctx.get_variable(*id_a); + let v_b = ctx.get_variable(*id_b); + if v_b.bound.len() > 0 { + if v_a.bound.len() == 0 { + return Err("unbounded a".to_string()); + } else { + let diff: Vec<_> = v_a + .bound + .iter() + .filter(|x| !v_b.bound.contains(x)) + .collect(); + if diff.len() > 0 { + return Err("different domain".to_string()); + } + } + } + sub.insert(*id_b, a.clone().into()); + Ok(()) + } + (TypeVariable(id_a), _) => { + let v_a = ctx.get_variable(*id_a); + if v_a.bound.len() == 1 && &v_a.bound[0] == b.as_ref() { + Ok(()) + } else { + Err("different domain".to_string()) + } + } + (_, TypeVariable(id_b)) => { + let v_b = ctx.get_variable(*id_b); + if v_b.bound.len() == 0 || v_b.bound.contains(&a) { + sub.insert(*id_b, a.clone().into()); + Ok(()) + } else { + Err("different domain".to_string()) + } + } + (_, VirtualClassType(id_b)) => { + let mut parents; + match a.as_ref() { + ClassType(id_a) => { + parents = [*id_a].to_vec(); + } + VirtualClassType(id_a) => { + let a = ctx.get_class(*id_a); + parents = a.parents.clone(); + } + _ => { + return Err("cannot substitute non-class type into virtual class".to_string()); + } + }; + while !parents.is_empty() { + if *id_b == parents[0] { + return Ok(()); + } + let c = ctx.get_class(parents.remove(0)); + parents.extend_from_slice(&c.parents); + } + Err("not subtype".to_string()) + }, + (ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => { + if id_a != id_b || param_a.len() != param_b.len() { + Err("different parametric types".to_string()) + } else { + for (x, y) in param_a.iter().zip(param_b.iter()) { + find_subst(ctx, assumptions, sub, x.clone(), y.clone())?; + } + Ok(()) + } + }, + (_, _) => { + if a == b { + Ok(()) + } else { + Err("not equal".to_string()) + } + } + } +} + + diff --git a/nac3type/src/lib.rs b/nac3type/src/lib.rs index fed80085ec..40217dcd46 100644 --- a/nac3type/src/lib.rs +++ b/nac3type/src/lib.rs @@ -1,5 +1,5 @@ extern crate rustpython_parser; - mod types; +mod inference; diff --git a/nac3type/src/types.rs b/nac3type/src/types.rs index 35b21bb345..adbfc70b37 100644 --- a/nac3type/src/types.rs +++ b/nac3type/src/types.rs @@ -1,25 +1,26 @@ use std::collections::HashMap; +use std::rc::Rc; -#[derive(PartialEq, Eq, Copy, Clone)] +#[derive(PartialEq, Eq, Copy, Clone, Hash)] pub struct PrimitiveId(usize); -#[derive(PartialEq, Eq, Copy, Clone)] +#[derive(PartialEq, Eq, Copy, Clone, Hash)] pub struct ClassId(usize); -#[derive(PartialEq, Eq, Copy, Clone)] +#[derive(PartialEq, Eq, Copy, Clone, Hash)] pub struct ParamId(usize); #[derive(PartialEq, Eq, Copy, Clone, Hash)] pub struct VariableId(usize); -#[derive(PartialEq, Eq, Clone)] +#[derive(PartialEq, Eq, Clone, Hash)] pub enum Type { BotType, SelfType, PrimitiveType(PrimitiveId), ClassType(ClassId), VirtualClassType(ClassId), - ParametricType(ParamId, Vec), + ParametricType(ParamId, Vec>), TypeVariable(VariableId), } @@ -80,7 +81,11 @@ impl<'a> GlobalContext<'a> { self.class_defs.push(def); } pub fn add_parametric(&mut self, def: ParametricDef<'a>) { - let params = def.params.iter().map(|&v| Type::TypeVariable(v)).collect(); + let params = def + .params + .iter() + .map(|&v| Type::TypeVariable(v).into()) + .collect(); self.sym_table.insert( def.base.name, Type::ParametricType(ParamId(self.parametric_defs.len()), params), @@ -100,36 +105,36 @@ impl<'a> GlobalContext<'a> { self.var_defs.push(def); } - pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> Option<&mut TypeDef<'a>> { - self.primitive_defs.get_mut(id.0) + pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> { + self.primitive_defs.get_mut(id.0).unwrap() } - pub fn get_primitive(&self, id: PrimitiveId) -> Option<&TypeDef> { - self.primitive_defs.get(id.0) + pub fn get_primitive(&self, id: PrimitiveId) -> &TypeDef { + self.primitive_defs.get(id.0).unwrap() } - pub fn get_class_mut(&mut self, id: ClassId) -> Option<&mut ClassDef<'a>> { - self.class_defs.get_mut(id.0) + pub fn get_class_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> { + self.class_defs.get_mut(id.0).unwrap() } - pub fn get_class(&self, id: ClassId) -> Option<&ClassDef> { - self.class_defs.get(id.0) + pub fn get_class(&self, id: ClassId) -> &ClassDef { + self.class_defs.get(id.0).unwrap() } - pub fn get_parametric_mut(&mut self, id: ParamId) -> Option<&mut ParametricDef<'a>> { - self.parametric_defs.get_mut(id.0) + pub fn get_parametric_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> { + self.parametric_defs.get_mut(id.0).unwrap() } - pub fn get_parametric(&self, id: ParamId) -> Option<&ParametricDef> { - self.parametric_defs.get(id.0) + pub fn get_parametric(&self, id: ParamId) -> &ParametricDef { + self.parametric_defs.get(id.0).unwrap() } - pub fn get_variable_mut(&mut self, id: VariableId) -> Option<&mut VarDef<'a>> { - self.var_defs.get_mut(id.0) + pub fn get_variable_mut(&mut self, id: VariableId) -> &mut VarDef<'a> { + self.var_defs.get_mut(id.0).unwrap() } - pub fn get_variable(&self, id: VariableId) -> Option<&VarDef> { - self.var_defs.get(id.0) + pub fn get_variable(&self, id: VariableId) -> &VarDef { + self.var_defs.get(id.0).unwrap() } pub fn get_type(&self, name: &str) -> Option { @@ -139,41 +144,49 @@ impl<'a> GlobalContext<'a> { } impl Type { - pub fn subst(&self, map: &Option>) -> Type { - if let Some(m) = map { - match self { - Type::TypeVariable(id) => m.get(id).unwrap_or(self).clone(), - Type::ParametricType(id, params) => { - Type::ParametricType(*id, params.iter().map(|v| v.subst(map)).collect()) - } - _ => self.clone(), - } - } else { - self.clone() + pub fn subst(&self, map: &HashMap) -> Type { + match self { + Type::TypeVariable(id) => map.get(id).unwrap_or(self).clone(), + Type::ParametricType(id, params) => Type::ParametricType( + *id, + params + .iter() + .map(|v| v.as_ref().subst(map).into()) + .collect(), + ), + _ => self.clone(), } } pub fn inv_subst(&self, map: &Vec<(Type, Type)>) -> Type { for (from, to) in map.iter() { if self == from { - return to.clone() + return to.clone(); } } match self { - Type::ParametricType(id, params) => { - Type::ParametricType(*id, params.iter().map(|v| v.inv_subst(map)).collect()) - }, - _ => self.clone() + Type::ParametricType(id, params) => Type::ParametricType( + *id, + params + .iter() + .map(|v| v.as_ref().inv_subst(map).into()) + .collect(), + ), + _ => self.clone(), } } - pub fn get_subst(&self, ctx: &GlobalContext) -> Option> { + pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap { match self { Type::ParametricType(id, params) => { - let vars = &ctx.get_parametric(*id).unwrap().params; - Some(vars.iter().zip(params).map(|(v, p)| (*v, p.clone())).collect()) - }, - _ => None + let vars = &ctx.get_parametric(*id).params; + vars.iter() + .zip(params) + .map(|(v, p)| (*v, p.as_ref().clone())) + .collect() + } + // if this proves to be slow, we can use option type + _ => HashMap::new(), } } }