forked from M-Labs/nac3
implementing inference
This commit is contained in:
parent
0bca238642
commit
fa02dc8271
107
nac3type/src/inference.rs
Normal file
107
nac3type/src/inference.rs
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
use super::types::{Type::*, *};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
fn find_subst(
|
||||||
|
ctx: &GlobalContext,
|
||||||
|
assumptions: & HashMap<VariableId, Rc<Type>>,
|
||||||
|
sub: &mut HashMap<VariableId, Rc<Type>>,
|
||||||
|
mut a: Rc<Type>,
|
||||||
|
mut b: Rc<Type>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// TODO: fix error messages later
|
||||||
|
if let TypeVariable(id) = a.as_ref() {
|
||||||
|
if let Some(c) = assumptions.get(&id) {
|
||||||
|
a = c.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let TypeVariable(id) = b.as_ref() {
|
||||||
|
if let Some(c) = sub.get(&id) {
|
||||||
|
b = c.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match (a.as_ref(), b.as_ref()) {
|
||||||
|
(BotType, _) => Ok(()),
|
||||||
|
(TypeVariable(id_a), TypeVariable(id_b)) => {
|
||||||
|
let v_a = ctx.get_variable(*id_a);
|
||||||
|
let v_b = ctx.get_variable(*id_b);
|
||||||
|
if v_b.bound.len() > 0 {
|
||||||
|
if v_a.bound.len() == 0 {
|
||||||
|
return Err("unbounded a".to_string());
|
||||||
|
} else {
|
||||||
|
let diff: Vec<_> = v_a
|
||||||
|
.bound
|
||||||
|
.iter()
|
||||||
|
.filter(|x| !v_b.bound.contains(x))
|
||||||
|
.collect();
|
||||||
|
if diff.len() > 0 {
|
||||||
|
return Err("different domain".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sub.insert(*id_b, a.clone().into());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
(TypeVariable(id_a), _) => {
|
||||||
|
let v_a = ctx.get_variable(*id_a);
|
||||||
|
if v_a.bound.len() == 1 && &v_a.bound[0] == b.as_ref() {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err("different domain".to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(_, TypeVariable(id_b)) => {
|
||||||
|
let v_b = ctx.get_variable(*id_b);
|
||||||
|
if v_b.bound.len() == 0 || v_b.bound.contains(&a) {
|
||||||
|
sub.insert(*id_b, a.clone().into());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err("different domain".to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(_, VirtualClassType(id_b)) => {
|
||||||
|
let mut parents;
|
||||||
|
match a.as_ref() {
|
||||||
|
ClassType(id_a) => {
|
||||||
|
parents = [*id_a].to_vec();
|
||||||
|
}
|
||||||
|
VirtualClassType(id_a) => {
|
||||||
|
let a = ctx.get_class(*id_a);
|
||||||
|
parents = a.parents.clone();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err("cannot substitute non-class type into virtual class".to_string());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
while !parents.is_empty() {
|
||||||
|
if *id_b == parents[0] {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
let c = ctx.get_class(parents.remove(0));
|
||||||
|
parents.extend_from_slice(&c.parents);
|
||||||
|
}
|
||||||
|
Err("not subtype".to_string())
|
||||||
|
},
|
||||||
|
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
|
||||||
|
if id_a != id_b || param_a.len() != param_b.len() {
|
||||||
|
Err("different parametric types".to_string())
|
||||||
|
} else {
|
||||||
|
for (x, y) in param_a.iter().zip(param_b.iter()) {
|
||||||
|
find_subst(ctx, assumptions, sub, x.clone(), y.clone())?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
(_, _) => {
|
||||||
|
if a == b {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err("not equal".to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
|||||||
extern crate rustpython_parser;
|
extern crate rustpython_parser;
|
||||||
|
|
||||||
|
|
||||||
mod types;
|
mod types;
|
||||||
|
mod inference;
|
||||||
|
|
||||||
|
@ -1,25 +1,26 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::rc::Rc;
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||||
pub struct PrimitiveId(usize);
|
pub struct PrimitiveId(usize);
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||||
pub struct ClassId(usize);
|
pub struct ClassId(usize);
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||||
pub struct ParamId(usize);
|
pub struct ParamId(usize);
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||||
pub struct VariableId(usize);
|
pub struct VariableId(usize);
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone)]
|
#[derive(PartialEq, Eq, Clone, Hash)]
|
||||||
pub enum Type {
|
pub enum Type {
|
||||||
BotType,
|
BotType,
|
||||||
SelfType,
|
SelfType,
|
||||||
PrimitiveType(PrimitiveId),
|
PrimitiveType(PrimitiveId),
|
||||||
ClassType(ClassId),
|
ClassType(ClassId),
|
||||||
VirtualClassType(ClassId),
|
VirtualClassType(ClassId),
|
||||||
ParametricType(ParamId, Vec<Type>),
|
ParametricType(ParamId, Vec<Rc<Type>>),
|
||||||
TypeVariable(VariableId),
|
TypeVariable(VariableId),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,7 +81,11 @@ impl<'a> GlobalContext<'a> {
|
|||||||
self.class_defs.push(def);
|
self.class_defs.push(def);
|
||||||
}
|
}
|
||||||
pub fn add_parametric(&mut self, def: ParametricDef<'a>) {
|
pub fn add_parametric(&mut self, def: ParametricDef<'a>) {
|
||||||
let params = def.params.iter().map(|&v| Type::TypeVariable(v)).collect();
|
let params = def
|
||||||
|
.params
|
||||||
|
.iter()
|
||||||
|
.map(|&v| Type::TypeVariable(v).into())
|
||||||
|
.collect();
|
||||||
self.sym_table.insert(
|
self.sym_table.insert(
|
||||||
def.base.name,
|
def.base.name,
|
||||||
Type::ParametricType(ParamId(self.parametric_defs.len()), params),
|
Type::ParametricType(ParamId(self.parametric_defs.len()), params),
|
||||||
@ -100,36 +105,36 @@ impl<'a> GlobalContext<'a> {
|
|||||||
self.var_defs.push(def);
|
self.var_defs.push(def);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> Option<&mut TypeDef<'a>> {
|
pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
|
||||||
self.primitive_defs.get_mut(id.0)
|
self.primitive_defs.get_mut(id.0).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_primitive(&self, id: PrimitiveId) -> Option<&TypeDef> {
|
pub fn get_primitive(&self, id: PrimitiveId) -> &TypeDef {
|
||||||
self.primitive_defs.get(id.0)
|
self.primitive_defs.get(id.0).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_class_mut(&mut self, id: ClassId) -> Option<&mut ClassDef<'a>> {
|
pub fn get_class_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> {
|
||||||
self.class_defs.get_mut(id.0)
|
self.class_defs.get_mut(id.0).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_class(&self, id: ClassId) -> Option<&ClassDef> {
|
pub fn get_class(&self, id: ClassId) -> &ClassDef {
|
||||||
self.class_defs.get(id.0)
|
self.class_defs.get(id.0).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_parametric_mut(&mut self, id: ParamId) -> Option<&mut ParametricDef<'a>> {
|
pub fn get_parametric_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> {
|
||||||
self.parametric_defs.get_mut(id.0)
|
self.parametric_defs.get_mut(id.0).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_parametric(&self, id: ParamId) -> Option<&ParametricDef> {
|
pub fn get_parametric(&self, id: ParamId) -> &ParametricDef {
|
||||||
self.parametric_defs.get(id.0)
|
self.parametric_defs.get(id.0).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_variable_mut(&mut self, id: VariableId) -> Option<&mut VarDef<'a>> {
|
pub fn get_variable_mut(&mut self, id: VariableId) -> &mut VarDef<'a> {
|
||||||
self.var_defs.get_mut(id.0)
|
self.var_defs.get_mut(id.0).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_variable(&self, id: VariableId) -> Option<&VarDef> {
|
pub fn get_variable(&self, id: VariableId) -> &VarDef {
|
||||||
self.var_defs.get(id.0)
|
self.var_defs.get(id.0).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_type(&self, name: &str) -> Option<Type> {
|
pub fn get_type(&self, name: &str) -> Option<Type> {
|
||||||
@ -139,41 +144,49 @@ impl<'a> GlobalContext<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Type {
|
impl Type {
|
||||||
pub fn subst(&self, map: &Option<HashMap<VariableId, Type>>) -> Type {
|
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> Type {
|
||||||
if let Some(m) = map {
|
|
||||||
match self {
|
match self {
|
||||||
Type::TypeVariable(id) => m.get(id).unwrap_or(self).clone(),
|
Type::TypeVariable(id) => map.get(id).unwrap_or(self).clone(),
|
||||||
Type::ParametricType(id, params) => {
|
Type::ParametricType(id, params) => Type::ParametricType(
|
||||||
Type::ParametricType(*id, params.iter().map(|v| v.subst(map)).collect())
|
*id,
|
||||||
}
|
params
|
||||||
|
.iter()
|
||||||
|
.map(|v| v.as_ref().subst(map).into())
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
_ => self.clone(),
|
_ => self.clone(),
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
self.clone()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn inv_subst(&self, map: &Vec<(Type, Type)>) -> Type {
|
pub fn inv_subst(&self, map: &Vec<(Type, Type)>) -> Type {
|
||||||
for (from, to) in map.iter() {
|
for (from, to) in map.iter() {
|
||||||
if self == from {
|
if self == from {
|
||||||
return to.clone()
|
return to.clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
match self {
|
match self {
|
||||||
Type::ParametricType(id, params) => {
|
Type::ParametricType(id, params) => Type::ParametricType(
|
||||||
Type::ParametricType(*id, params.iter().map(|v| v.inv_subst(map)).collect())
|
*id,
|
||||||
},
|
params
|
||||||
_ => self.clone()
|
.iter()
|
||||||
|
.map(|v| v.as_ref().inv_subst(map).into())
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
_ => self.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_subst(&self, ctx: &GlobalContext) -> Option<HashMap<VariableId, Type>> {
|
pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap<VariableId, Type> {
|
||||||
match self {
|
match self {
|
||||||
Type::ParametricType(id, params) => {
|
Type::ParametricType(id, params) => {
|
||||||
let vars = &ctx.get_parametric(*id).unwrap().params;
|
let vars = &ctx.get_parametric(*id).params;
|
||||||
Some(vars.iter().zip(params).map(|(v, p)| (*v, p.clone())).collect())
|
vars.iter()
|
||||||
},
|
.zip(params)
|
||||||
_ => None
|
.map(|(v, p)| (*v, p.as_ref().clone()))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
// if this proves to be slow, we can use option type
|
||||||
|
_ => HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user