forked from M-Labs/nac3
implemented inference
rc nightmare...
This commit is contained in:
parent
fa02dc8271
commit
fd3e1d4923
@ -46,7 +46,7 @@ fn find_subst(
|
||||
}
|
||||
(TypeVariable(id_a), _) => {
|
||||
let v_a = ctx.get_variable(*id_a);
|
||||
if v_a.bound.len() == 1 && &v_a.bound[0] == b.as_ref() {
|
||||
if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err("different domain".to_string())
|
||||
@ -83,7 +83,7 @@ fn find_subst(
|
||||
parents.extend_from_slice(&c.parents);
|
||||
}
|
||||
Err("not subtype".to_string())
|
||||
},
|
||||
}
|
||||
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
|
||||
if id_a != id_b || param_a.len() != param_b.len() {
|
||||
Err("different parametric types".to_string())
|
||||
@ -93,7 +93,7 @@ fn find_subst(
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
}
|
||||
(_, _) => {
|
||||
if a == b {
|
||||
Ok(())
|
||||
@ -104,4 +104,68 @@ fn find_subst(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolve_call(
|
||||
ctx: &GlobalContext,
|
||||
obj: Option<Rc<Type>>,
|
||||
func: &str,
|
||||
args: Rc<Type>,
|
||||
assumptions: &mut HashMap<VariableId, Rc<Type>>,
|
||||
) -> Result<Option<Rc<Type>>, String> {
|
||||
let obj = obj.as_ref().map(|v| Rc::new(v.subst(assumptions)));
|
||||
let mut subst = obj
|
||||
.as_ref()
|
||||
.map(|v| v.get_subst(ctx))
|
||||
.unwrap_or(HashMap::new());
|
||||
|
||||
let fun = match &obj {
|
||||
Some(obj) => {
|
||||
let base = match obj.as_ref() {
|
||||
TypeVariable(id) => {
|
||||
let v = ctx.get_variable(*id);
|
||||
if v.bound.len() == 0 {
|
||||
return Err("unbounded type var".to_string());
|
||||
}
|
||||
let results: Result<Vec<_>, String> = v
|
||||
.bound
|
||||
.iter()
|
||||
.map(|ins| {
|
||||
assumptions.insert(*id, ins.clone());
|
||||
resolve_call(ctx, Some(obj.clone()), func, args.clone(), assumptions)
|
||||
})
|
||||
.collect();
|
||||
let results = results?;
|
||||
if results.iter().all(|v| v == &results[0]) {
|
||||
return Ok(results[0].clone());
|
||||
}
|
||||
let mut results = results.iter().zip(v.bound.iter()).map(|(r, ins)| {
|
||||
r.as_ref()
|
||||
.map(|v| v.inv_subst(&[(ins.clone(), obj.clone().into())]))
|
||||
});
|
||||
let first = results.next().unwrap();
|
||||
if results.all(|v| v == first) {
|
||||
return Ok(first);
|
||||
} else {
|
||||
return Err("divergent type after substitution".to_string());
|
||||
}
|
||||
},
|
||||
PrimitiveType(id) => &ctx.get_primitive(*id),
|
||||
ClassType(id) | VirtualClassType(id) => &ctx.get_class(*id).base,
|
||||
ParametricType(id, _) => &ctx.get_parametric(*id).base,
|
||||
_ => return Err("not supported".to_string()),
|
||||
};
|
||||
base.methods.get(func)
|
||||
},
|
||||
None => ctx.get_fn(func),
|
||||
}
|
||||
.ok_or("no such function".to_string())?;
|
||||
|
||||
find_subst(ctx, assumptions, &mut subst, args, fun.args.clone())?;
|
||||
let result = fun.result.as_ref().map(|v| v.subst(&subst));
|
||||
Ok(result.map(|result| {
|
||||
if let SelfType = result {
|
||||
obj.unwrap()
|
||||
} else {
|
||||
result.into()
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
@ -1,19 +1,19 @@
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||
pub struct PrimitiveId(usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||
pub struct ClassId(usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||
pub struct ParamId(usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash)]
|
||||
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
|
||||
pub struct VariableId(usize);
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Hash)]
|
||||
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
|
||||
pub enum Type {
|
||||
BotType,
|
||||
SelfType,
|
||||
@ -25,8 +25,8 @@ pub enum Type {
|
||||
}
|
||||
|
||||
pub struct FnDef {
|
||||
pub args: Vec<Type>,
|
||||
pub result: Option<Type>,
|
||||
pub args: Rc<Type>,
|
||||
pub result: Option<Rc<Type>>,
|
||||
}
|
||||
|
||||
pub struct TypeDef<'a> {
|
||||
@ -47,15 +47,22 @@ pub struct ParametricDef<'a> {
|
||||
|
||||
pub struct VarDef<'a> {
|
||||
pub name: &'a str,
|
||||
pub bound: Vec<Type>,
|
||||
pub bound: Vec<Rc<Type>>,
|
||||
}
|
||||
|
||||
pub const TUPLE_TYPE: ParamId = ParamId(0);
|
||||
pub const LIST_TYPE: ParamId = ParamId(1);
|
||||
|
||||
pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0);
|
||||
pub const INT32_TYPE: PrimitiveId = PrimitiveId(1);
|
||||
|
||||
pub struct GlobalContext<'a> {
|
||||
primitive_defs: Vec<TypeDef<'a>>,
|
||||
class_defs: Vec<ClassDef<'a>>,
|
||||
parametric_defs: Vec<ParametricDef<'a>>,
|
||||
var_defs: Vec<VarDef<'a>>,
|
||||
sym_table: HashMap<&'a str, Type>,
|
||||
fn_table: HashMap<&'a str, FnDef>,
|
||||
}
|
||||
|
||||
impl<'a> GlobalContext<'a> {
|
||||
@ -69,6 +76,7 @@ impl<'a> GlobalContext<'a> {
|
||||
class_defs: Vec::new(),
|
||||
parametric_defs: Vec::new(),
|
||||
var_defs: Vec::new(),
|
||||
fn_table: HashMap::new(),
|
||||
sym_table,
|
||||
};
|
||||
}
|
||||
@ -80,11 +88,12 @@ impl<'a> GlobalContext<'a> {
|
||||
);
|
||||
self.class_defs.push(def);
|
||||
}
|
||||
|
||||
pub fn add_parametric(&mut self, def: ParametricDef<'a>) {
|
||||
let params = def
|
||||
.params
|
||||
.iter()
|
||||
.map(|&v| Type::TypeVariable(v).into())
|
||||
.map(|&v| Rc::new(Type::TypeVariable(v)))
|
||||
.collect();
|
||||
self.sym_table.insert(
|
||||
def.base.name,
|
||||
@ -105,6 +114,14 @@ impl<'a> GlobalContext<'a> {
|
||||
self.var_defs.push(def);
|
||||
}
|
||||
|
||||
pub fn add_fn(&'a mut self, name: &'a str, def: FnDef) {
|
||||
self.fn_table.insert(name, def);
|
||||
}
|
||||
|
||||
pub fn get_fn(&self, name: &str) -> Option<&FnDef> {
|
||||
self.fn_table.get(name)
|
||||
}
|
||||
|
||||
pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
|
||||
self.primitive_defs.get_mut(id.0).unwrap()
|
||||
}
|
||||
@ -144,9 +161,9 @@ impl<'a> GlobalContext<'a> {
|
||||
}
|
||||
|
||||
impl Type {
|
||||
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> Type {
|
||||
pub fn subst(&self, map: &HashMap<VariableId, Rc<Type>>) -> Type {
|
||||
match self {
|
||||
Type::TypeVariable(id) => map.get(id).unwrap_or(self).clone(),
|
||||
Type::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(),
|
||||
Type::ParametricType(id, params) => Type::ParametricType(
|
||||
*id,
|
||||
params
|
||||
@ -158,9 +175,9 @@ impl Type {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inv_subst(&self, map: &Vec<(Type, Type)>) -> Type {
|
||||
pub fn inv_subst(&self, map: &[(Rc<Type>, Rc<Type>)]) -> Rc<Type> {
|
||||
for (from, to) in map.iter() {
|
||||
if self == from {
|
||||
if self == from.as_ref() {
|
||||
return to.clone();
|
||||
}
|
||||
}
|
||||
@ -173,16 +190,16 @@ impl Type {
|
||||
.collect(),
|
||||
),
|
||||
_ => self.clone(),
|
||||
}
|
||||
}.into()
|
||||
}
|
||||
|
||||
pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap<VariableId, Type> {
|
||||
pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap<VariableId, Rc<Type>> {
|
||||
match self {
|
||||
Type::ParametricType(id, params) => {
|
||||
let vars = &ctx.get_parametric(*id).params;
|
||||
vars.iter()
|
||||
.zip(params)
|
||||
.map(|(v, p)| (*v, p.as_ref().clone()))
|
||||
.map(|(v, p)| (*v, p.as_ref().clone().into()))
|
||||
.collect()
|
||||
}
|
||||
// if this proves to be slow, we can use option type
|
||||
|
Loading…
Reference in New Issue
Block a user