From fd3e1d49239d39ab4afca0c2543bafd2ea01dfdc Mon Sep 17 00:00:00 2001 From: pca006132 Date: Mon, 28 Dec 2020 11:06:32 +0800 Subject: [PATCH] implemented inference rc nightmare... --- nac3type/src/inference.rs | 72 ++++++++++++++++++++++++++++++++++++--- nac3type/src/types.rs | 49 +++++++++++++++++--------- 2 files changed, 101 insertions(+), 20 deletions(-) diff --git a/nac3type/src/inference.rs b/nac3type/src/inference.rs index 9432938184..685c07a6a9 100644 --- a/nac3type/src/inference.rs +++ b/nac3type/src/inference.rs @@ -4,7 +4,7 @@ use std::rc::Rc; fn find_subst( ctx: &GlobalContext, - assumptions: & HashMap>, + assumptions: &HashMap>, sub: &mut HashMap>, mut a: Rc, mut b: Rc, @@ -46,7 +46,7 @@ fn find_subst( } (TypeVariable(id_a), _) => { let v_a = ctx.get_variable(*id_a); - if v_a.bound.len() == 1 && &v_a.bound[0] == b.as_ref() { + if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() { Ok(()) } else { Err("different domain".to_string()) @@ -83,7 +83,7 @@ fn find_subst( parents.extend_from_slice(&c.parents); } Err("not subtype".to_string()) - }, + } (ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => { if id_a != id_b || param_a.len() != param_b.len() { Err("different parametric types".to_string()) @@ -93,7 +93,7 @@ fn find_subst( } Ok(()) } - }, + } (_, _) => { if a == b { Ok(()) @@ -104,4 +104,68 @@ fn find_subst( } } +pub fn resolve_call( + ctx: &GlobalContext, + obj: Option>, + func: &str, + args: Rc, + assumptions: &mut HashMap>, +) -> Result>, String> { + let obj = obj.as_ref().map(|v| Rc::new(v.subst(assumptions))); + let mut subst = obj + .as_ref() + .map(|v| v.get_subst(ctx)) + .unwrap_or(HashMap::new()); + let fun = match &obj { + Some(obj) => { + let base = match obj.as_ref() { + TypeVariable(id) => { + let v = ctx.get_variable(*id); + if v.bound.len() == 0 { + return Err("unbounded type var".to_string()); + } + let results: Result, String> = v + .bound + .iter() + .map(|ins| { + assumptions.insert(*id, ins.clone()); + resolve_call(ctx, Some(obj.clone()), func, args.clone(), assumptions) + }) + .collect(); + let results = results?; + if results.iter().all(|v| v == &results[0]) { + return Ok(results[0].clone()); + } + let mut results = results.iter().zip(v.bound.iter()).map(|(r, ins)| { + r.as_ref() + .map(|v| v.inv_subst(&[(ins.clone(), obj.clone().into())])) + }); + let first = results.next().unwrap(); + if results.all(|v| v == first) { + return Ok(first); + } else { + return Err("divergent type after substitution".to_string()); + } + }, + PrimitiveType(id) => &ctx.get_primitive(*id), + ClassType(id) | VirtualClassType(id) => &ctx.get_class(*id).base, + ParametricType(id, _) => &ctx.get_parametric(*id).base, + _ => return Err("not supported".to_string()), + }; + base.methods.get(func) + }, + None => ctx.get_fn(func), + } + .ok_or("no such function".to_string())?; + + find_subst(ctx, assumptions, &mut subst, args, fun.args.clone())?; + let result = fun.result.as_ref().map(|v| v.subst(&subst)); + Ok(result.map(|result| { + if let SelfType = result { + obj.unwrap() + } else { + result.into() + } + })) +} diff --git a/nac3type/src/types.rs b/nac3type/src/types.rs index adbfc70b37..ad11c36fc8 100644 --- a/nac3type/src/types.rs +++ b/nac3type/src/types.rs @@ -1,19 +1,19 @@ use std::collections::HashMap; use std::rc::Rc; -#[derive(PartialEq, Eq, Copy, Clone, Hash)] +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] pub struct PrimitiveId(usize); -#[derive(PartialEq, Eq, Copy, Clone, Hash)] +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] pub struct ClassId(usize); -#[derive(PartialEq, Eq, Copy, Clone, Hash)] +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] pub struct ParamId(usize); -#[derive(PartialEq, Eq, Copy, Clone, Hash)] +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] pub struct VariableId(usize); -#[derive(PartialEq, Eq, Clone, Hash)] +#[derive(PartialEq, Eq, Clone, Hash, Debug)] pub enum Type { BotType, SelfType, @@ -25,8 +25,8 @@ pub enum Type { } pub struct FnDef { - pub args: Vec, - pub result: Option, + pub args: Rc, + pub result: Option>, } pub struct TypeDef<'a> { @@ -47,15 +47,22 @@ pub struct ParametricDef<'a> { pub struct VarDef<'a> { pub name: &'a str, - pub bound: Vec, + pub bound: Vec>, } +pub const TUPLE_TYPE: ParamId = ParamId(0); +pub const LIST_TYPE: ParamId = ParamId(1); + +pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0); +pub const INT32_TYPE: PrimitiveId = PrimitiveId(1); + pub struct GlobalContext<'a> { primitive_defs: Vec>, class_defs: Vec>, parametric_defs: Vec>, var_defs: Vec>, sym_table: HashMap<&'a str, Type>, + fn_table: HashMap<&'a str, FnDef>, } impl<'a> GlobalContext<'a> { @@ -69,6 +76,7 @@ impl<'a> GlobalContext<'a> { class_defs: Vec::new(), parametric_defs: Vec::new(), var_defs: Vec::new(), + fn_table: HashMap::new(), sym_table, }; } @@ -80,11 +88,12 @@ impl<'a> GlobalContext<'a> { ); self.class_defs.push(def); } + pub fn add_parametric(&mut self, def: ParametricDef<'a>) { let params = def .params .iter() - .map(|&v| Type::TypeVariable(v).into()) + .map(|&v| Rc::new(Type::TypeVariable(v))) .collect(); self.sym_table.insert( def.base.name, @@ -105,6 +114,14 @@ impl<'a> GlobalContext<'a> { self.var_defs.push(def); } + pub fn add_fn(&'a mut self, name: &'a str, def: FnDef) { + self.fn_table.insert(name, def); + } + + pub fn get_fn(&self, name: &str) -> Option<&FnDef> { + self.fn_table.get(name) + } + pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> { self.primitive_defs.get_mut(id.0).unwrap() } @@ -144,9 +161,9 @@ impl<'a> GlobalContext<'a> { } impl Type { - pub fn subst(&self, map: &HashMap) -> Type { + pub fn subst(&self, map: &HashMap>) -> Type { match self { - Type::TypeVariable(id) => map.get(id).unwrap_or(self).clone(), + Type::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(), Type::ParametricType(id, params) => Type::ParametricType( *id, params @@ -158,9 +175,9 @@ impl Type { } } - pub fn inv_subst(&self, map: &Vec<(Type, Type)>) -> Type { + pub fn inv_subst(&self, map: &[(Rc, Rc)]) -> Rc { for (from, to) in map.iter() { - if self == from { + if self == from.as_ref() { return to.clone(); } } @@ -173,16 +190,16 @@ impl Type { .collect(), ), _ => self.clone(), - } + }.into() } - pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap { + pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap> { match self { Type::ParametricType(id, params) => { let vars = &ctx.get_parametric(*id).params; vars.iter() .zip(params) - .map(|(v, p)| (*v, p.as_ref().clone())) + .map(|(v, p)| (*v, p.as_ref().clone().into())) .collect() } // if this proves to be slow, we can use option type