implementing inference

This commit is contained in:
pca006132 2020-12-26 11:45:57 +08:00 committed by pca006132
parent 0bca238642
commit fa02dc8271
3 changed files with 164 additions and 44 deletions

107
nac3type/src/inference.rs Normal file
View File

@ -0,0 +1,107 @@
use super::types::{Type::*, *};
use std::collections::HashMap;
use std::rc::Rc;
fn find_subst(
ctx: &GlobalContext,
assumptions: & HashMap<VariableId, Rc<Type>>,
sub: &mut HashMap<VariableId, Rc<Type>>,
mut a: Rc<Type>,
mut b: Rc<Type>,
) -> Result<(), String> {
// TODO: fix error messages later
if let TypeVariable(id) = a.as_ref() {
if let Some(c) = assumptions.get(&id) {
a = c.clone();
}
}
if let TypeVariable(id) = b.as_ref() {
if let Some(c) = sub.get(&id) {
b = c.clone();
}
}
match (a.as_ref(), b.as_ref()) {
(BotType, _) => Ok(()),
(TypeVariable(id_a), TypeVariable(id_b)) => {
let v_a = ctx.get_variable(*id_a);
let v_b = ctx.get_variable(*id_b);
if v_b.bound.len() > 0 {
if v_a.bound.len() == 0 {
return Err("unbounded a".to_string());
} else {
let diff: Vec<_> = v_a
.bound
.iter()
.filter(|x| !v_b.bound.contains(x))
.collect();
if diff.len() > 0 {
return Err("different domain".to_string());
}
}
}
sub.insert(*id_b, a.clone().into());
Ok(())
}
(TypeVariable(id_a), _) => {
let v_a = ctx.get_variable(*id_a);
if v_a.bound.len() == 1 && &v_a.bound[0] == b.as_ref() {
Ok(())
} else {
Err("different domain".to_string())
}
}
(_, TypeVariable(id_b)) => {
let v_b = ctx.get_variable(*id_b);
if v_b.bound.len() == 0 || v_b.bound.contains(&a) {
sub.insert(*id_b, a.clone().into());
Ok(())
} else {
Err("different domain".to_string())
}
}
(_, VirtualClassType(id_b)) => {
let mut parents;
match a.as_ref() {
ClassType(id_a) => {
parents = [*id_a].to_vec();
}
VirtualClassType(id_a) => {
let a = ctx.get_class(*id_a);
parents = a.parents.clone();
}
_ => {
return Err("cannot substitute non-class type into virtual class".to_string());
}
};
while !parents.is_empty() {
if *id_b == parents[0] {
return Ok(());
}
let c = ctx.get_class(parents.remove(0));
parents.extend_from_slice(&c.parents);
}
Err("not subtype".to_string())
},
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
if id_a != id_b || param_a.len() != param_b.len() {
Err("different parametric types".to_string())
} else {
for (x, y) in param_a.iter().zip(param_b.iter()) {
find_subst(ctx, assumptions, sub, x.clone(), y.clone())?;
}
Ok(())
}
},
(_, _) => {
if a == b {
Ok(())
} else {
Err("not equal".to_string())
}
}
}
}

View File

@ -1,5 +1,5 @@
extern crate rustpython_parser; extern crate rustpython_parser;
mod types; mod types;
mod inference;

View File

@ -1,25 +1,26 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc;
#[derive(PartialEq, Eq, Copy, Clone)] #[derive(PartialEq, Eq, Copy, Clone, Hash)]
pub struct PrimitiveId(usize); pub struct PrimitiveId(usize);
#[derive(PartialEq, Eq, Copy, Clone)] #[derive(PartialEq, Eq, Copy, Clone, Hash)]
pub struct ClassId(usize); pub struct ClassId(usize);
#[derive(PartialEq, Eq, Copy, Clone)] #[derive(PartialEq, Eq, Copy, Clone, Hash)]
pub struct ParamId(usize); pub struct ParamId(usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash)] #[derive(PartialEq, Eq, Copy, Clone, Hash)]
pub struct VariableId(usize); pub struct VariableId(usize);
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone, Hash)]
pub enum Type { pub enum Type {
BotType, BotType,
SelfType, SelfType,
PrimitiveType(PrimitiveId), PrimitiveType(PrimitiveId),
ClassType(ClassId), ClassType(ClassId),
VirtualClassType(ClassId), VirtualClassType(ClassId),
ParametricType(ParamId, Vec<Type>), ParametricType(ParamId, Vec<Rc<Type>>),
TypeVariable(VariableId), TypeVariable(VariableId),
} }
@ -80,7 +81,11 @@ impl<'a> GlobalContext<'a> {
self.class_defs.push(def); self.class_defs.push(def);
} }
pub fn add_parametric(&mut self, def: ParametricDef<'a>) { pub fn add_parametric(&mut self, def: ParametricDef<'a>) {
let params = def.params.iter().map(|&v| Type::TypeVariable(v)).collect(); let params = def
.params
.iter()
.map(|&v| Type::TypeVariable(v).into())
.collect();
self.sym_table.insert( self.sym_table.insert(
def.base.name, def.base.name,
Type::ParametricType(ParamId(self.parametric_defs.len()), params), Type::ParametricType(ParamId(self.parametric_defs.len()), params),
@ -100,36 +105,36 @@ impl<'a> GlobalContext<'a> {
self.var_defs.push(def); self.var_defs.push(def);
} }
pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> Option<&mut TypeDef<'a>> { pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
self.primitive_defs.get_mut(id.0) self.primitive_defs.get_mut(id.0).unwrap()
} }
pub fn get_primitive(&self, id: PrimitiveId) -> Option<&TypeDef> { pub fn get_primitive(&self, id: PrimitiveId) -> &TypeDef {
self.primitive_defs.get(id.0) self.primitive_defs.get(id.0).unwrap()
} }
pub fn get_class_mut(&mut self, id: ClassId) -> Option<&mut ClassDef<'a>> { pub fn get_class_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> {
self.class_defs.get_mut(id.0) self.class_defs.get_mut(id.0).unwrap()
} }
pub fn get_class(&self, id: ClassId) -> Option<&ClassDef> { pub fn get_class(&self, id: ClassId) -> &ClassDef {
self.class_defs.get(id.0) self.class_defs.get(id.0).unwrap()
} }
pub fn get_parametric_mut(&mut self, id: ParamId) -> Option<&mut ParametricDef<'a>> { pub fn get_parametric_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> {
self.parametric_defs.get_mut(id.0) self.parametric_defs.get_mut(id.0).unwrap()
} }
pub fn get_parametric(&self, id: ParamId) -> Option<&ParametricDef> { pub fn get_parametric(&self, id: ParamId) -> &ParametricDef {
self.parametric_defs.get(id.0) self.parametric_defs.get(id.0).unwrap()
} }
pub fn get_variable_mut(&mut self, id: VariableId) -> Option<&mut VarDef<'a>> { pub fn get_variable_mut(&mut self, id: VariableId) -> &mut VarDef<'a> {
self.var_defs.get_mut(id.0) self.var_defs.get_mut(id.0).unwrap()
} }
pub fn get_variable(&self, id: VariableId) -> Option<&VarDef> { pub fn get_variable(&self, id: VariableId) -> &VarDef {
self.var_defs.get(id.0) self.var_defs.get(id.0).unwrap()
} }
pub fn get_type(&self, name: &str) -> Option<Type> { pub fn get_type(&self, name: &str) -> Option<Type> {
@ -139,41 +144,49 @@ impl<'a> GlobalContext<'a> {
} }
impl Type { impl Type {
pub fn subst(&self, map: &Option<HashMap<VariableId, Type>>) -> Type { pub fn subst(&self, map: &HashMap<VariableId, Type>) -> Type {
if let Some(m) = map {
match self { match self {
Type::TypeVariable(id) => m.get(id).unwrap_or(self).clone(), Type::TypeVariable(id) => map.get(id).unwrap_or(self).clone(),
Type::ParametricType(id, params) => { Type::ParametricType(id, params) => Type::ParametricType(
Type::ParametricType(*id, params.iter().map(|v| v.subst(map)).collect()) *id,
} params
.iter()
.map(|v| v.as_ref().subst(map).into())
.collect(),
),
_ => self.clone(), _ => self.clone(),
} }
} else {
self.clone()
}
} }
pub fn inv_subst(&self, map: &Vec<(Type, Type)>) -> Type { pub fn inv_subst(&self, map: &Vec<(Type, Type)>) -> Type {
for (from, to) in map.iter() { for (from, to) in map.iter() {
if self == from { if self == from {
return to.clone() return to.clone();
} }
} }
match self { match self {
Type::ParametricType(id, params) => { Type::ParametricType(id, params) => Type::ParametricType(
Type::ParametricType(*id, params.iter().map(|v| v.inv_subst(map)).collect()) *id,
}, params
_ => self.clone() .iter()
.map(|v| v.as_ref().inv_subst(map).into())
.collect(),
),
_ => self.clone(),
} }
} }
pub fn get_subst(&self, ctx: &GlobalContext) -> Option<HashMap<VariableId, Type>> { pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap<VariableId, Type> {
match self { match self {
Type::ParametricType(id, params) => { Type::ParametricType(id, params) => {
let vars = &ctx.get_parametric(*id).unwrap().params; let vars = &ctx.get_parametric(*id).params;
Some(vars.iter().zip(params).map(|(v, p)| (*v, p.clone())).collect()) vars.iter()
}, .zip(params)
_ => None .map(|(v, p)| (*v, p.as_ref().clone()))
.collect()
}
// if this proves to be slow, we can use option type
_ => HashMap::new(),
} }
} }
} }