implemented inference

rc nightmare...
This commit is contained in:
pca006132 2020-12-28 11:06:32 +08:00 committed by pca006132
parent fa02dc8271
commit fd3e1d4923
2 changed files with 101 additions and 20 deletions

View File

@ -4,7 +4,7 @@ use std::rc::Rc;
fn find_subst( fn find_subst(
ctx: &GlobalContext, ctx: &GlobalContext,
assumptions: & HashMap<VariableId, Rc<Type>>, assumptions: &HashMap<VariableId, Rc<Type>>,
sub: &mut HashMap<VariableId, Rc<Type>>, sub: &mut HashMap<VariableId, Rc<Type>>,
mut a: Rc<Type>, mut a: Rc<Type>,
mut b: Rc<Type>, mut b: Rc<Type>,
@ -46,7 +46,7 @@ fn find_subst(
} }
(TypeVariable(id_a), _) => { (TypeVariable(id_a), _) => {
let v_a = ctx.get_variable(*id_a); let v_a = ctx.get_variable(*id_a);
if v_a.bound.len() == 1 && &v_a.bound[0] == b.as_ref() { if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() {
Ok(()) Ok(())
} else { } else {
Err("different domain".to_string()) Err("different domain".to_string())
@ -83,7 +83,7 @@ fn find_subst(
parents.extend_from_slice(&c.parents); parents.extend_from_slice(&c.parents);
} }
Err("not subtype".to_string()) Err("not subtype".to_string())
}, }
(ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => { (ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => {
if id_a != id_b || param_a.len() != param_b.len() { if id_a != id_b || param_a.len() != param_b.len() {
Err("different parametric types".to_string()) Err("different parametric types".to_string())
@ -93,7 +93,7 @@ fn find_subst(
} }
Ok(()) Ok(())
} }
}, }
(_, _) => { (_, _) => {
if a == b { if a == b {
Ok(()) Ok(())
@ -104,4 +104,68 @@ fn find_subst(
} }
} }
pub fn resolve_call(
ctx: &GlobalContext,
obj: Option<Rc<Type>>,
func: &str,
args: Rc<Type>,
assumptions: &mut HashMap<VariableId, Rc<Type>>,
) -> Result<Option<Rc<Type>>, String> {
let obj = obj.as_ref().map(|v| Rc::new(v.subst(assumptions)));
let mut subst = obj
.as_ref()
.map(|v| v.get_subst(ctx))
.unwrap_or(HashMap::new());
let fun = match &obj {
Some(obj) => {
let base = match obj.as_ref() {
TypeVariable(id) => {
let v = ctx.get_variable(*id);
if v.bound.len() == 0 {
return Err("unbounded type var".to_string());
}
let results: Result<Vec<_>, String> = v
.bound
.iter()
.map(|ins| {
assumptions.insert(*id, ins.clone());
resolve_call(ctx, Some(obj.clone()), func, args.clone(), assumptions)
})
.collect();
let results = results?;
if results.iter().all(|v| v == &results[0]) {
return Ok(results[0].clone());
}
let mut results = results.iter().zip(v.bound.iter()).map(|(r, ins)| {
r.as_ref()
.map(|v| v.inv_subst(&[(ins.clone(), obj.clone().into())]))
});
let first = results.next().unwrap();
if results.all(|v| v == first) {
return Ok(first);
} else {
return Err("divergent type after substitution".to_string());
}
},
PrimitiveType(id) => &ctx.get_primitive(*id),
ClassType(id) | VirtualClassType(id) => &ctx.get_class(*id).base,
ParametricType(id, _) => &ctx.get_parametric(*id).base,
_ => return Err("not supported".to_string()),
};
base.methods.get(func)
},
None => ctx.get_fn(func),
}
.ok_or("no such function".to_string())?;
find_subst(ctx, assumptions, &mut subst, args, fun.args.clone())?;
let result = fun.result.as_ref().map(|v| v.subst(&subst));
Ok(result.map(|result| {
if let SelfType = result {
obj.unwrap()
} else {
result.into()
}
}))
}

View File

@ -1,19 +1,19 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
#[derive(PartialEq, Eq, Copy, Clone, Hash)] #[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct PrimitiveId(usize); pub struct PrimitiveId(usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash)] #[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct ClassId(usize); pub struct ClassId(usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash)] #[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct ParamId(usize); pub struct ParamId(usize);
#[derive(PartialEq, Eq, Copy, Clone, Hash)] #[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct VariableId(usize); pub struct VariableId(usize);
#[derive(PartialEq, Eq, Clone, Hash)] #[derive(PartialEq, Eq, Clone, Hash, Debug)]
pub enum Type { pub enum Type {
BotType, BotType,
SelfType, SelfType,
@ -25,8 +25,8 @@ pub enum Type {
} }
pub struct FnDef { pub struct FnDef {
pub args: Vec<Type>, pub args: Rc<Type>,
pub result: Option<Type>, pub result: Option<Rc<Type>>,
} }
pub struct TypeDef<'a> { pub struct TypeDef<'a> {
@ -47,15 +47,22 @@ pub struct ParametricDef<'a> {
pub struct VarDef<'a> { pub struct VarDef<'a> {
pub name: &'a str, pub name: &'a str,
pub bound: Vec<Type>, pub bound: Vec<Rc<Type>>,
} }
pub const TUPLE_TYPE: ParamId = ParamId(0);
pub const LIST_TYPE: ParamId = ParamId(1);
pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0);
pub const INT32_TYPE: PrimitiveId = PrimitiveId(1);
pub struct GlobalContext<'a> { pub struct GlobalContext<'a> {
primitive_defs: Vec<TypeDef<'a>>, primitive_defs: Vec<TypeDef<'a>>,
class_defs: Vec<ClassDef<'a>>, class_defs: Vec<ClassDef<'a>>,
parametric_defs: Vec<ParametricDef<'a>>, parametric_defs: Vec<ParametricDef<'a>>,
var_defs: Vec<VarDef<'a>>, var_defs: Vec<VarDef<'a>>,
sym_table: HashMap<&'a str, Type>, sym_table: HashMap<&'a str, Type>,
fn_table: HashMap<&'a str, FnDef>,
} }
impl<'a> GlobalContext<'a> { impl<'a> GlobalContext<'a> {
@ -69,6 +76,7 @@ impl<'a> GlobalContext<'a> {
class_defs: Vec::new(), class_defs: Vec::new(),
parametric_defs: Vec::new(), parametric_defs: Vec::new(),
var_defs: Vec::new(), var_defs: Vec::new(),
fn_table: HashMap::new(),
sym_table, sym_table,
}; };
} }
@ -80,11 +88,12 @@ impl<'a> GlobalContext<'a> {
); );
self.class_defs.push(def); self.class_defs.push(def);
} }
pub fn add_parametric(&mut self, def: ParametricDef<'a>) { pub fn add_parametric(&mut self, def: ParametricDef<'a>) {
let params = def let params = def
.params .params
.iter() .iter()
.map(|&v| Type::TypeVariable(v).into()) .map(|&v| Rc::new(Type::TypeVariable(v)))
.collect(); .collect();
self.sym_table.insert( self.sym_table.insert(
def.base.name, def.base.name,
@ -105,6 +114,14 @@ impl<'a> GlobalContext<'a> {
self.var_defs.push(def); self.var_defs.push(def);
} }
pub fn add_fn(&'a mut self, name: &'a str, def: FnDef) {
self.fn_table.insert(name, def);
}
pub fn get_fn(&self, name: &str) -> Option<&FnDef> {
self.fn_table.get(name)
}
pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> { pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> {
self.primitive_defs.get_mut(id.0).unwrap() self.primitive_defs.get_mut(id.0).unwrap()
} }
@ -144,9 +161,9 @@ impl<'a> GlobalContext<'a> {
} }
impl Type { impl Type {
pub fn subst(&self, map: &HashMap<VariableId, Type>) -> Type { pub fn subst(&self, map: &HashMap<VariableId, Rc<Type>>) -> Type {
match self { match self {
Type::TypeVariable(id) => map.get(id).unwrap_or(self).clone(), Type::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(),
Type::ParametricType(id, params) => Type::ParametricType( Type::ParametricType(id, params) => Type::ParametricType(
*id, *id,
params params
@ -158,9 +175,9 @@ impl Type {
} }
} }
pub fn inv_subst(&self, map: &Vec<(Type, Type)>) -> Type { pub fn inv_subst(&self, map: &[(Rc<Type>, Rc<Type>)]) -> Rc<Type> {
for (from, to) in map.iter() { for (from, to) in map.iter() {
if self == from { if self == from.as_ref() {
return to.clone(); return to.clone();
} }
} }
@ -173,16 +190,16 @@ impl Type {
.collect(), .collect(),
), ),
_ => self.clone(), _ => self.clone(),
} }.into()
} }
pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap<VariableId, Type> { pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap<VariableId, Rc<Type>> {
match self { match self {
Type::ParametricType(id, params) => { Type::ParametricType(id, params) => {
let vars = &ctx.get_parametric(*id).params; let vars = &ctx.get_parametric(*id).params;
vars.iter() vars.iter()
.zip(params) .zip(params)
.map(|(v, p)| (*v, p.as_ref().clone())) .map(|(v, p)| (*v, p.as_ref().clone().into()))
.collect() .collect()
} }
// if this proves to be slow, we can use option type // if this proves to be slow, we can use option type