forked from M-Labs/nac3
implemented inference
rc nightmare...
This commit is contained in:
parent
fa02dc8271
commit
fd3e1d4923
@ -46,7 +46,7 @@ fn find_subst(
|
|||||||
}
|
}
|
||||||
(TypeVariable(id_a), _) => {
|
(TypeVariable(id_a), _) => {
|
||||||
let v_a = ctx.get_variable(*id_a);
|
let v_a = ctx.get_variable(*id_a);
|
||||||
if v_a.bound.len() == 1 && &v_a.bound[0] == b.as_ref() {
|
if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
Err("different domain".to_string())
|
Err("different domain".to_string())
|
||||||
@ -83,7 +83,7 @@ fn find_subst(
|
|||||||
parents.extend_from_slice(&c.parents);
|
parents.extend_from_slice(&c.parents);
|
||||||
}
|
}
|
||||||
Err("not subtype".to_string())
|
Err("not subtype".to_string())
|
||||||
},
|
}
|
||||||
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
|
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
|
||||||
if id_a != id_b || param_a.len() != param_b.len() {
|
if id_a != id_b || param_a.len() != param_b.len() {
|
||||||
Err("different parametric types".to_string())
|
Err("different parametric types".to_string())
|
||||||
@ -93,7 +93,7 @@ fn find_subst(
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
(_, _) => {
|
(_, _) => {
|
||||||
if a == b {
|
if a == b {
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -104,4 +104,68 @@ fn find_subst(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn resolve_call(
|
||||||
|
ctx: &GlobalContext,
|
||||||
|
obj: Option<Rc<Type>>,
|
||||||
|
func: &str,
|
||||||
|
args: Rc<Type>,
|
||||||
|
assumptions: &mut HashMap<VariableId, Rc<Type>>,
|
||||||
|
) -> Result<Option<Rc<Type>>, String> {
|
||||||
|
let obj = obj.as_ref().map(|v| Rc::new(v.subst(assumptions)));
|
||||||
|
let mut subst = obj
|
||||||
|
.as_ref()
|
||||||
|
.map(|v| v.get_subst(ctx))
|
||||||
|
.unwrap_or(HashMap::new());
|
||||||
|
|
||||||
|
let fun = match &obj {
|
||||||
|
Some(obj) => {
|
||||||
|
let base = match obj.as_ref() {
|
||||||
|
TypeVariable(id) => {
|
||||||
|
let v = ctx.get_variable(*id);
|
||||||
|
if v.bound.len() == 0 {
|
||||||
|
return Err("unbounded type var".to_string());
|
||||||
|
}
|
||||||
|
let results: Result<Vec<_>, String> = v
|
||||||
|
.bound
|
||||||
|
.iter()
|
||||||
|
.map(|ins| {
|
||||||
|
assumptions.insert(*id, ins.clone());
|
||||||
|
resolve_call(ctx, Some(obj.clone()), func, args.clone(), assumptions)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let results = results?;
|
||||||
|
if results.iter().all(|v| v == &results[0]) {
|
||||||
|
return Ok(results[0].clone());
|
||||||
|
}
|
||||||
|
let mut results = results.iter().zip(v.bound.iter()).map(|(r, ins)| {
|
||||||
|
r.as_ref()
|
||||||
|
.map(|v| v.inv_subst(&[(ins.clone(), obj.clone().into())]))
|
||||||
|
});
|
||||||
|
let first = results.next().unwrap();
|
||||||
|
if results.all(|v| v == first) {
|
||||||
|
return Ok(first);
|
||||||
|
} else {
|
||||||
|
return Err("divergent type after substitution".to_string());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
PrimitiveType(id) => &ctx.get_primitive(*id),
|
||||||
|
ClassType(id) | VirtualClassType(id) => &ctx.get_class(*id).base,
|
||||||
|
ParametricType(id, _) => &ctx.get_parametric(*id).base,
|
||||||
|
_ => return Err("not supported".to_string()),
|
||||||
|
};
|
||||||
|
base.methods.get(func)
|
||||||
|
},
|
||||||
|
None => ctx.get_fn(func),
|
||||||
|
}
|
||||||
|
.ok_or("no such function".to_string())?;
|
||||||
|
|
||||||
|
find_subst(ctx, assumptions, &mut subst, args, fun.args.clone())?;
|
||||||
|
let result = fun.result.as_ref().map(|v| v.subst(&subst));
|
||||||
|
Ok(result.map(|result| {
|
||||||
|
if let SelfType = result {
|
||||||
|
obj.unwrap()
|
||||||
|
} else {
|
||||||
|
result.into()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
@ -1,19 +1,19 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||||
pub struct PrimitiveId(usize);
|
pub struct PrimitiveId(usize);
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||||
pub struct ClassId(usize);
|
pub struct ClassId(usize);
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||||
pub struct ParamId(usize);
|
pub struct ParamId(usize);
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||||
pub struct VariableId(usize);
|
pub struct VariableId(usize);
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Hash)]
|
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
|
||||||
pub enum Type {
|
pub enum Type {
|
||||||
BotType,
|
BotType,
|
||||||
SelfType,
|
SelfType,
|
||||||
@ -25,8 +25,8 @@ pub enum Type {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct FnDef {
|
pub struct FnDef {
|
||||||
pub args: Vec<Type>,
|
pub args: Rc<Type>,
|
||||||
pub result: Option<Type>,
|
pub result: Option<Rc<Type>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TypeDef<'a> {
|
pub struct TypeDef<'a> {
|
||||||
@ -47,15 +47,22 @@ pub struct ParametricDef<'a> {
|
|||||||
|
|
||||||
pub struct VarDef<'a> {
|
pub struct VarDef<'a> {
|
||||||
pub name: &'a str,
|
pub name: &'a str,
|
||||||
pub bound: Vec<Type>,
|
pub bound: Vec<Rc<Type>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const TUPLE_TYPE: ParamId = ParamId(0);
|
||||||
|
pub const LIST_TYPE: ParamId = ParamId(1);
|
||||||
|
|
||||||
|
pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0);
|
||||||
|
pub const INT32_TYPE: PrimitiveId = PrimitiveId(1);
|
||||||
|
|
||||||
pub struct GlobalContext<'a> {
|
pub struct GlobalContext<'a> {
|
||||||
primitive_defs: Vec<TypeDef<'a>>,
|
primitive_defs: Vec<TypeDef<'a>>,
|
||||||
class_defs: Vec<ClassDef<'a>>,
|
class_defs: Vec<ClassDef<'a>>,
|
||||||
parametric_defs: Vec<ParametricDef<'a>>,
|
parametric_defs: Vec<ParametricDef<'a>>,
|
||||||
var_defs: Vec<VarDef<'a>>,
|
var_defs: Vec<VarDef<'a>>,
|
||||||
sym_table: HashMap<&'a str, Type>,
|
sym_table: HashMap<&'a str, Type>,
|
||||||
|
fn_table: HashMap<&'a str, FnDef>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> GlobalContext<'a> {
|
impl<'a> GlobalContext<'a> {
|
||||||
@ -69,6 +76,7 @@ impl<'a> GlobalContext<'a> {
|
|||||||
class_defs: Vec::new(),
|
class_defs: Vec::new(),
|
||||||
parametric_defs: Vec::new(),
|
parametric_defs: Vec::new(),
|
||||||
var_defs: Vec::new(),
|
var_defs: Vec::new(),
|
||||||
|
fn_table: HashMap::new(),
|
||||||
sym_table,
|
sym_table,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -80,11 +88,12 @@ impl<'a> GlobalContext<'a> {
|
|||||||
);
|
);
|
||||||
self.class_defs.push(def);
|
self.class_defs.push(def);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_parametric(&mut self, def: ParametricDef<'a>) {
|
pub fn add_parametric(&mut self, def: ParametricDef<'a>) {
|
||||||
let params = def
|
let params = def
|
||||||
.params
|
.params
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&v| Type::TypeVariable(v).into())
|
.map(|&v| Rc::new(Type::TypeVariable(v)))
|
||||||
.collect();
|
.collect();
|
||||||
self.sym_table.insert(
|
self.sym_table.insert(
|
||||||
def.base.name,
|
def.base.name,
|
||||||
@ -105,6 +114,14 @@ impl<'a> GlobalContext<'a> {
|
|||||||
self.var_defs.push(def);
|
self.var_defs.push(def);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn add_fn(&'a mut self, name: &'a str, def: FnDef) {
|
||||||
|
self.fn_table.insert(name, def);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_fn(&self, name: &str) -> Option<&FnDef> {
|
||||||
|
self.fn_table.get(name)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
|
pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
|
||||||
self.primitive_defs.get_mut(id.0).unwrap()
|
self.primitive_defs.get_mut(id.0).unwrap()
|
||||||
}
|
}
|
||||||
@ -144,9 +161,9 @@ impl<'a> GlobalContext<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Type {
|
impl Type {
|
||||||
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> Type {
|
pub fn subst(&self, map: &HashMap<VariableId, Rc<Type>>) -> Type {
|
||||||
match self {
|
match self {
|
||||||
Type::TypeVariable(id) => map.get(id).unwrap_or(self).clone(),
|
Type::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(),
|
||||||
Type::ParametricType(id, params) => Type::ParametricType(
|
Type::ParametricType(id, params) => Type::ParametricType(
|
||||||
*id,
|
*id,
|
||||||
params
|
params
|
||||||
@ -158,9 +175,9 @@ impl Type {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn inv_subst(&self, map: &Vec<(Type, Type)>) -> Type {
|
pub fn inv_subst(&self, map: &[(Rc<Type>, Rc<Type>)]) -> Rc<Type> {
|
||||||
for (from, to) in map.iter() {
|
for (from, to) in map.iter() {
|
||||||
if self == from {
|
if self == from.as_ref() {
|
||||||
return to.clone();
|
return to.clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -173,16 +190,16 @@ impl Type {
|
|||||||
.collect(),
|
.collect(),
|
||||||
),
|
),
|
||||||
_ => self.clone(),
|
_ => self.clone(),
|
||||||
}
|
}.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap<VariableId, Type> {
|
pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap<VariableId, Rc<Type>> {
|
||||||
match self {
|
match self {
|
||||||
Type::ParametricType(id, params) => {
|
Type::ParametricType(id, params) => {
|
||||||
let vars = &ctx.get_parametric(*id).params;
|
let vars = &ctx.get_parametric(*id).params;
|
||||||
vars.iter()
|
vars.iter()
|
||||||
.zip(params)
|
.zip(params)
|
||||||
.map(|(v, p)| (*v, p.as_ref().clone()))
|
.map(|(v, p)| (*v, p.as_ref().clone().into()))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
// if this proves to be slow, we can use option type
|
// if this proves to be slow, we can use option type
|
||||||
|
Loading…
Reference in New Issue
Block a user