forked from M-Labs/nac3
implementing inference
This commit is contained in:
parent
0bca238642
commit
fa02dc8271
107
nac3type/src/inference.rs
Normal file
107
nac3type/src/inference.rs
Normal file
@ -0,0 +1,107 @@
|
||||
use super::types::{Type::*, *};
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
fn find_subst(
|
||||
ctx: &GlobalContext,
|
||||
assumptions: & HashMap<VariableId, Rc<Type>>,
|
||||
sub: &mut HashMap<VariableId, Rc<Type>>,
|
||||
mut a: Rc<Type>,
|
||||
mut b: Rc<Type>,
|
||||
) -> Result<(), String> {
|
||||
// TODO: fix error messages later
|
||||
if let TypeVariable(id) = a.as_ref() {
|
||||
if let Some(c) = assumptions.get(&id) {
|
||||
a = c.clone();
|
||||
}
|
||||
}
|
||||
|
||||
if let TypeVariable(id) = b.as_ref() {
|
||||
if let Some(c) = sub.get(&id) {
|
||||
b = c.clone();
|
||||
}
|
||||
}
|
||||
|
||||
match (a.as_ref(), b.as_ref()) {
|
||||
(BotType, _) => Ok(()),
|
||||
(TypeVariable(id_a), TypeVariable(id_b)) => {
|
||||
let v_a = ctx.get_variable(*id_a);
|
||||
let v_b = ctx.get_variable(*id_b);
|
||||
if v_b.bound.len() > 0 {
|
||||
if v_a.bound.len() == 0 {
|
||||
return Err("unbounded a".to_string());
|
||||
} else {
|
||||
let diff: Vec<_> = v_a
|
||||
.bound
|
||||
.iter()
|
||||
.filter(|x| !v_b.bound.contains(x))
|
||||
.collect();
|
||||
if diff.len() > 0 {
|
||||
return Err("different domain".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
sub.insert(*id_b, a.clone().into());
|
||||
Ok(())
|
||||
}
|
||||
(TypeVariable(id_a), _) => {
|
||||
let v_a = ctx.get_variable(*id_a);
|
||||
if v_a.bound.len() == 1 && &v_a.bound[0] == b.as_ref() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err("different domain".to_string())
|
||||
}
|
||||
}
|
||||
(_, TypeVariable(id_b)) => {
|
||||
let v_b = ctx.get_variable(*id_b);
|
||||
if v_b.bound.len() == 0 || v_b.bound.contains(&a) {
|
||||
sub.insert(*id_b, a.clone().into());
|
||||
Ok(())
|
||||
} else {
|
||||
Err("different domain".to_string())
|
||||
}
|
||||
}
|
||||
(_, VirtualClassType(id_b)) => {
|
||||
let mut parents;
|
||||
match a.as_ref() {
|
||||
ClassType(id_a) => {
|
||||
parents = [*id_a].to_vec();
|
||||
}
|
||||
VirtualClassType(id_a) => {
|
||||
let a = ctx.get_class(*id_a);
|
||||
parents = a.parents.clone();
|
||||
}
|
||||
_ => {
|
||||
return Err("cannot substitute non-class type into virtual class".to_string());
|
||||
}
|
||||
};
|
||||
while !parents.is_empty() {
|
||||
if *id_b == parents[0] {
|
||||
return Ok(());
|
||||
}
|
||||
let c = ctx.get_class(parents.remove(0));
|
||||
parents.extend_from_slice(&c.parents);
|
||||
}
|
||||
Err("not subtype".to_string())
|
||||
},
|
||||
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
|
||||
if id_a != id_b || param_a.len() != param_b.len() {
|
||||
Err("different parametric types".to_string())
|
||||
} else {
|
||||
for (x, y) in param_a.iter().zip(param_b.iter()) {
|
||||
find_subst(ctx, assumptions, sub, x.clone(), y.clone())?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
(_, _) => {
|
||||
if a == b {
|
||||
Ok(())
|
||||
} else {
|
||||
Err("not equal".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
extern crate rustpython_parser;
|
||||
|
||||
|
||||
mod types;
|
||||
mod inference;
|
||||
|
||||
|
@ -1,25 +1,26 @@
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||
pub struct PrimitiveId(usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||
pub struct ClassId(usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||
pub struct ParamId(usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||
pub struct VariableId(usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
#[derive(PartialEq, Eq, Clone, Hash)]
|
||||
pub enum Type {
|
||||
BotType,
|
||||
SelfType,
|
||||
PrimitiveType(PrimitiveId),
|
||||
ClassType(ClassId),
|
||||
VirtualClassType(ClassId),
|
||||
ParametricType(ParamId, Vec<Type>),
|
||||
ParametricType(ParamId, Vec<Rc<Type>>),
|
||||
TypeVariable(VariableId),
|
||||
}
|
||||
|
||||
@ -80,7 +81,11 @@ impl<'a> GlobalContext<'a> {
|
||||
self.class_defs.push(def);
|
||||
}
|
||||
pub fn add_parametric(&mut self, def: ParametricDef<'a>) {
|
||||
let params = def.params.iter().map(|&v| Type::TypeVariable(v)).collect();
|
||||
let params = def
|
||||
.params
|
||||
.iter()
|
||||
.map(|&v| Type::TypeVariable(v).into())
|
||||
.collect();
|
||||
self.sym_table.insert(
|
||||
def.base.name,
|
||||
Type::ParametricType(ParamId(self.parametric_defs.len()), params),
|
||||
@ -100,36 +105,36 @@ impl<'a> GlobalContext<'a> {
|
||||
self.var_defs.push(def);
|
||||
}
|
||||
|
||||
pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> Option<&mut TypeDef<'a>> {
|
||||
self.primitive_defs.get_mut(id.0)
|
||||
pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
|
||||
self.primitive_defs.get_mut(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_primitive(&self, id: PrimitiveId) -> Option<&TypeDef> {
|
||||
self.primitive_defs.get(id.0)
|
||||
pub fn get_primitive(&self, id: PrimitiveId) -> &TypeDef {
|
||||
self.primitive_defs.get(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_class_mut(&mut self, id: ClassId) -> Option<&mut ClassDef<'a>> {
|
||||
self.class_defs.get_mut(id.0)
|
||||
pub fn get_class_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> {
|
||||
self.class_defs.get_mut(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_class(&self, id: ClassId) -> Option<&ClassDef> {
|
||||
self.class_defs.get(id.0)
|
||||
pub fn get_class(&self, id: ClassId) -> &ClassDef {
|
||||
self.class_defs.get(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_parametric_mut(&mut self, id: ParamId) -> Option<&mut ParametricDef<'a>> {
|
||||
self.parametric_defs.get_mut(id.0)
|
||||
pub fn get_parametric_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> {
|
||||
self.parametric_defs.get_mut(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_parametric(&self, id: ParamId) -> Option<&ParametricDef> {
|
||||
self.parametric_defs.get(id.0)
|
||||
pub fn get_parametric(&self, id: ParamId) -> &ParametricDef {
|
||||
self.parametric_defs.get(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_variable_mut(&mut self, id: VariableId) -> Option<&mut VarDef<'a>> {
|
||||
self.var_defs.get_mut(id.0)
|
||||
pub fn get_variable_mut(&mut self, id: VariableId) -> &mut VarDef<'a> {
|
||||
self.var_defs.get_mut(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_variable(&self, id: VariableId) -> Option<&VarDef> {
|
||||
self.var_defs.get(id.0)
|
||||
pub fn get_variable(&self, id: VariableId) -> &VarDef {
|
||||
self.var_defs.get(id.0).unwrap()
|
||||
}
|
||||
|
||||
pub fn get_type(&self, name: &str) -> Option<Type> {
|
||||
@ -139,41 +144,49 @@ impl<'a> GlobalContext<'a> {
|
||||
}
|
||||
|
||||
impl Type {
|
||||
pub fn subst(&self, map: &Option<HashMap<VariableId, Type>>) -> Type {
|
||||
if let Some(m) = map {
|
||||
match self {
|
||||
Type::TypeVariable(id) => m.get(id).unwrap_or(self).clone(),
|
||||
Type::ParametricType(id, params) => {
|
||||
Type::ParametricType(*id, params.iter().map(|v| v.subst(map)).collect())
|
||||
}
|
||||
_ => self.clone(),
|
||||
}
|
||||
} else {
|
||||
self.clone()
|
||||
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> Type {
|
||||
match self {
|
||||
Type::TypeVariable(id) => map.get(id).unwrap_or(self).clone(),
|
||||
Type::ParametricType(id, params) => Type::ParametricType(
|
||||
*id,
|
||||
params
|
||||
.iter()
|
||||
.map(|v| v.as_ref().subst(map).into())
|
||||
.collect(),
|
||||
),
|
||||
_ => self.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inv_subst(&self, map: &Vec<(Type, Type)>) -> Type {
|
||||
for (from, to) in map.iter() {
|
||||
if self == from {
|
||||
return to.clone()
|
||||
return to.clone();
|
||||
}
|
||||
}
|
||||
match self {
|
||||
Type::ParametricType(id, params) => {
|
||||
Type::ParametricType(*id, params.iter().map(|v| v.inv_subst(map)).collect())
|
||||
},
|
||||
_ => self.clone()
|
||||
Type::ParametricType(id, params) => Type::ParametricType(
|
||||
*id,
|
||||
params
|
||||
.iter()
|
||||
.map(|v| v.as_ref().inv_subst(map).into())
|
||||
.collect(),
|
||||
),
|
||||
_ => self.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_subst(&self, ctx: &GlobalContext) -> Option<HashMap<VariableId, Type>> {
|
||||
pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap<VariableId, Type> {
|
||||
match self {
|
||||
Type::ParametricType(id, params) => {
|
||||
let vars = &ctx.get_parametric(*id).unwrap().params;
|
||||
Some(vars.iter().zip(params).map(|(v, p)| (*v, p.clone())).collect())
|
||||
},
|
||||
_ => None
|
||||
let vars = &ctx.get_parametric(*id).params;
|
||||
vars.iter()
|
||||
.zip(params)
|
||||
.map(|(v, p)| (*v, p.as_ref().clone()))
|
||||
.collect()
|
||||
}
|
||||
// if this proves to be slow, we can use option type
|
||||
_ => HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user