From 9cba777b3bdebfa90733d9a7324a6610ec6f77b4 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Thu, 31 Dec 2020 11:28:13 +0800 Subject: [PATCH] added inference --- nac3core/src/inference.rs | 199 ++++++++++++++++++++++++++++++++++++++ nac3core/src/lib.rs | 1 + 2 files changed, 200 insertions(+) create mode 100644 nac3core/src/inference.rs diff --git a/nac3core/src/inference.rs b/nac3core/src/inference.rs new file mode 100644 index 00000000..a0820f5c --- /dev/null +++ b/nac3core/src/inference.rs @@ -0,0 +1,199 @@ +use super::typedef::{Type::*, *}; +use std::collections::HashMap; +use std::rc::Rc; + +fn find_subst( + ctx: &GlobalContext, + valuation: &Option<(VariableId, Rc)>, + sub: &mut HashMap>, + mut a: Rc, + mut b: Rc, +) -> Result<(), String> { + // TODO: fix error messages later + if let TypeVariable(id) = a.as_ref() { + if let Some((assumption_id, t)) = valuation { + if assumption_id == id { + a = t.clone(); + } + } + } + + let mut substituted = false; + if let TypeVariable(id) = b.as_ref() { + if let Some(c) = sub.get(&id) { + b = c.clone(); + substituted = true; + } + } + + match (a.as_ref(), b.as_ref()) { + (BotType, _) => Ok(()), + (TypeVariable(id_a), TypeVariable(id_b)) => { + if substituted { + return if id_a == id_b { + Ok(()) + } else { + Err("different variables".to_string()) + }; + } + let v_a = ctx.get_variable(*id_a); + let v_b = ctx.get_variable(*id_b); + if v_b.bound.len() > 0 { + if v_a.bound.len() == 0 { + return Err("unbounded a".to_string()); + } else { + let diff: Vec<_> = v_a + .bound + .iter() + .filter(|x| !v_b.bound.contains(x)) + .collect(); + if diff.len() > 0 { + return Err("different domain".to_string()); + } + } + } + sub.insert(*id_b, a.clone().into()); + Ok(()) + } + (TypeVariable(id_a), _) => { + let v_a = ctx.get_variable(*id_a); + if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() { + Ok(()) + } else { + Err("different domain".to_string()) + } + } + (_, TypeVariable(id_b)) => { + let v_b = ctx.get_variable(*id_b); + if v_b.bound.len() == 0 || v_b.bound.contains(&a) { + sub.insert(*id_b, a.clone().into()); + Ok(()) + } else { + Err("different domain".to_string()) + } + } + (_, VirtualClassType(id_b)) => { + let mut parents; + match a.as_ref() { + ClassType(id_a) => { + parents = [*id_a].to_vec(); + } + VirtualClassType(id_a) => { + parents = [*id_a].to_vec(); + } + _ => { + return Err("cannot substitute non-class type into virtual class".to_string()); + } + }; + while !parents.is_empty() { + if *id_b == parents[0] { + return Ok(()); + } + let c = ctx.get_class(parents.remove(0)); + parents.extend_from_slice(&c.parents); + } + Err("not subtype".to_string()) + } + (ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => { + if id_a != id_b || param_a.len() != param_b.len() { + Err("different parametric types".to_string()) + } else { + for (x, y) in param_a.iter().zip(param_b.iter()) { + find_subst(ctx, valuation, sub, x.clone(), y.clone())?; + } + Ok(()) + } + } + (_, _) => { + if a == b { + Ok(()) + } else { + Err("not equal".to_string()) + } + } + } +} + +fn resolve_call_rec( + ctx: &GlobalContext, + valuation: &Option<(VariableId, Rc)>, + obj: Option>, + func: &str, + args: &[Rc], +) -> Result>, String> { + let mut subst = obj + .as_ref() + .map(|v| v.get_subst(ctx)) + .unwrap_or(HashMap::new()); + + let fun = match &obj { + Some(obj) => { + let base = match obj.as_ref() { + TypeVariable(id) => { + let v = ctx.get_variable(*id); + if v.bound.len() == 0 { + return Err("unbounded type var".to_string()); + } + let results: Result, String> = v + .bound + .iter() + .map(|ins| { + resolve_call_rec( + ctx, + &Some((*id, ins.clone())), + Some(ins.clone()), + func, + args.clone(), + ) + }) + .collect(); + let results = results?; + if results.iter().all(|v| v == &results[0]) { + return Ok(results[0].clone()); + } + let mut results = results.iter().zip(v.bound.iter()).map(|(r, ins)| { + r.as_ref() + .map(|v| v.inv_subst(&[(ins.clone(), obj.clone().into())])) + }); + let first = results.next().unwrap(); + if results.all(|v| v == first) { + return Ok(first); + } else { + return Err("divergent type after substitution".to_string()); + } + } + PrimitiveType(id) => &ctx.get_primitive(*id), + ClassType(id) | VirtualClassType(id) => &ctx.get_class(*id).base, + ParametricType(id, _) => &ctx.get_parametric(*id).base, + _ => return Err("not supported".to_string()), + }; + base.methods.get(func) + } + None => ctx.get_fn(func), + } + .ok_or("no such function".to_string())?; + + if args.len() != fun.args.len() { + return Err("incorrect parameter number".to_string()); + } + for (a, b) in args.iter().zip(fun.args.iter()) { + find_subst(ctx, valuation, &mut subst, a.clone(), b.clone())?; + } + let result = fun.result.as_ref().map(|v| v.subst(&subst)); + Ok(result.map(|result| { + if let SelfType = result { + obj.unwrap() + } else { + result.into() + } + })) +} + +pub fn resolve_call( + ctx: &GlobalContext, + obj: Option>, + func: &str, + args: &[Rc], +) -> Result>, String> { + resolve_call_rec(ctx, &None, obj, func, args) +} diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index 801fcc46..761cc82e 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -2,6 +2,7 @@ extern crate num_bigint; extern crate inkwell; extern crate rustpython_parser; +pub mod inference; pub mod primitives; pub mod typedef;