forked from M-Labs/nac3
Merge remote-tracking branch 'origin/hm-inference' into hm-inference_anto
This commit is contained in:
commit
ae79533cfd
|
@ -1,6 +1,6 @@
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::convert::TryInto;
|
use std::convert::{TryInto, From};
|
||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
@ -17,6 +17,21 @@ use rustpython_parser::ast::{
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
|
||||||
|
pub struct CodeLocation {
|
||||||
|
row: usize,
|
||||||
|
col: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Location> for CodeLocation {
|
||||||
|
fn from(loc: Location) -> CodeLocation {
|
||||||
|
CodeLocation {
|
||||||
|
row: loc.row(),
|
||||||
|
col: loc.column()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct PrimitiveStore {
|
pub struct PrimitiveStore {
|
||||||
pub int32: Type,
|
pub int32: Type,
|
||||||
pub int64: Type,
|
pub int64: Type,
|
||||||
|
@ -37,6 +52,7 @@ pub struct Inferencer<'a> {
|
||||||
pub primitives: &'a PrimitiveStore,
|
pub primitives: &'a PrimitiveStore,
|
||||||
pub virtual_checks: &'a mut Vec<(Type, Type)>,
|
pub virtual_checks: &'a mut Vec<(Type, Type)>,
|
||||||
pub variable_mapping: HashMap<String, Type>,
|
pub variable_mapping: HashMap<String, Type>,
|
||||||
|
pub calls: &'a mut HashMap<CodeLocation, Rc<Call>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NaiveFolder();
|
struct NaiveFolder();
|
||||||
|
@ -215,6 +231,7 @@ impl<'a> Inferencer<'a> {
|
||||||
unifier: self.unifier,
|
unifier: self.unifier,
|
||||||
primitives: self.primitives,
|
primitives: self.primitives,
|
||||||
virtual_checks: self.virtual_checks,
|
virtual_checks: self.virtual_checks,
|
||||||
|
calls: self.calls,
|
||||||
variable_mapping,
|
variable_mapping,
|
||||||
};
|
};
|
||||||
let fun = FunSignature {
|
let fun = FunSignature {
|
||||||
|
@ -257,6 +274,7 @@ impl<'a> Inferencer<'a> {
|
||||||
virtual_checks: self.virtual_checks,
|
virtual_checks: self.virtual_checks,
|
||||||
variable_mapping,
|
variable_mapping,
|
||||||
primitives: self.primitives,
|
primitives: self.primitives,
|
||||||
|
calls: self.calls,
|
||||||
};
|
};
|
||||||
let elt = new_context.fold_expr(elt)?;
|
let elt = new_context.fold_expr(elt)?;
|
||||||
let generator = generators.pop().unwrap();
|
let generator = generators.pop().unwrap();
|
||||||
|
@ -379,6 +397,7 @@ impl<'a> Inferencer<'a> {
|
||||||
fun: RefCell::new(None),
|
fun: RefCell::new(None),
|
||||||
ret,
|
ret,
|
||||||
});
|
});
|
||||||
|
self.calls.insert(location.into(), call.clone());
|
||||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into()));
|
||||||
self.unifier.unify(func.custom.unwrap(), call)?;
|
self.unifier.unify(func.custom.unwrap(), call)?;
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,7 @@ struct TestEnvironment {
|
||||||
pub id_to_name: HashMap<usize, String>,
|
pub id_to_name: HashMap<usize, String>,
|
||||||
pub identifier_mapping: HashMap<String, Type>,
|
pub identifier_mapping: HashMap<String, Type>,
|
||||||
pub virtual_checks: Vec<(Type, Type)>,
|
pub virtual_checks: Vec<(Type, Type)>,
|
||||||
|
pub calls: HashMap<CodeLocation, Rc<Call>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TestEnvironment {
|
impl TestEnvironment {
|
||||||
|
@ -218,12 +219,13 @@ impl TestEnvironment {
|
||||||
function_data: FunctionData {
|
function_data: FunctionData {
|
||||||
resolver,
|
resolver,
|
||||||
bound_variables: Vec::new(),
|
bound_variables: Vec::new(),
|
||||||
return_type: None
|
return_type: None,
|
||||||
},
|
},
|
||||||
primitives,
|
primitives,
|
||||||
id_to_name,
|
id_to_name,
|
||||||
identifier_mapping,
|
identifier_mapping,
|
||||||
virtual_checks: Vec::new(),
|
virtual_checks: Vec::new(),
|
||||||
|
calls: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -234,6 +236,7 @@ impl TestEnvironment {
|
||||||
variable_mapping: Default::default(),
|
variable_mapping: Default::default(),
|
||||||
primitives: &mut self.primitives,
|
primitives: &mut self.primitives,
|
||||||
virtual_checks: &mut self.virtual_checks,
|
virtual_checks: &mut self.virtual_checks,
|
||||||
|
calls: &mut self.calls,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,6 +47,9 @@ pub enum TypeVarMeta {
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub enum TypeEnum {
|
pub enum TypeEnum {
|
||||||
|
TRigidVar {
|
||||||
|
id: u32,
|
||||||
|
},
|
||||||
TVar {
|
TVar {
|
||||||
id: u32,
|
id: u32,
|
||||||
meta: TypeVarMeta,
|
meta: TypeVarMeta,
|
||||||
|
@ -74,6 +77,7 @@ pub enum TypeEnum {
|
||||||
impl TypeEnum {
|
impl TypeEnum {
|
||||||
pub fn get_type_name(&self) -> &'static str {
|
pub fn get_type_name(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
|
TypeEnum::TRigidVar { .. } => "TRigidVar",
|
||||||
TypeEnum::TVar { .. } => "TVar",
|
TypeEnum::TVar { .. } => "TVar",
|
||||||
TypeEnum::TTuple { .. } => "TTuple",
|
TypeEnum::TTuple { .. } => "TTuple",
|
||||||
TypeEnum::TList { .. } => "TList",
|
TypeEnum::TList { .. } => "TList",
|
||||||
|
@ -127,6 +131,12 @@ impl Unifier {
|
||||||
self.unification_table.probe_value(a).clone()
|
self.unification_table.probe_value(a).clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_fresh_rigid_var(&mut self) -> (Type, u32) {
|
||||||
|
let id = self.var_id + 1;
|
||||||
|
self.var_id += 1;
|
||||||
|
(self.add_ty(TypeEnum::TRigidVar { id }), id)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_fresh_var(&mut self) -> (Type, u32) {
|
pub fn get_fresh_var(&mut self) -> (Type, u32) {
|
||||||
self.get_fresh_var_with_range(&[])
|
self.get_fresh_var_with_range(&[])
|
||||||
}
|
}
|
||||||
|
@ -139,9 +149,17 @@ impl Unifier {
|
||||||
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
|
(self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Unification would not unify rigid variables with other types, but we want to do this for
|
||||||
|
/// function instantiations, so we make it explicit.
|
||||||
|
pub fn replace_rigid_var(&mut self, rigid: Type, b: Type) {
|
||||||
|
assert!(matches!(&*self.get_ty(rigid), TypeEnum::TRigidVar { .. }));
|
||||||
|
self.set_a_to_b(rigid, b);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
||||||
use TypeEnum::*;
|
use TypeEnum::*;
|
||||||
match &*self.get_ty(a) {
|
match &*self.get_ty(a) {
|
||||||
|
TRigidVar { .. } => true,
|
||||||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||||
TCall { .. } => false,
|
TCall { .. } => false,
|
||||||
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
TList { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||||
|
@ -290,11 +308,8 @@ impl Unifier {
|
||||||
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
|
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
|
||||||
self.occur_check(a, b)?;
|
self.occur_check(a, b)?;
|
||||||
for (k, v) in map.borrow().iter() {
|
for (k, v) in map.borrow().iter() {
|
||||||
if let Some(ty) = fields.get(k) {
|
let ty = fields.get(k).ok_or_else(|| format!("No such attribute {}", k))?;
|
||||||
self.unify(*ty, *v)?;
|
self.unify(*ty, *v)?;
|
||||||
} else {
|
|
||||||
return Err(format!("No such attribute {}", k));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||||
self.unify(x, b)?;
|
self.unify(x, b)?;
|
||||||
|
@ -305,14 +320,11 @@ impl Unifier {
|
||||||
let ty = self.get_ty(*ty);
|
let ty = self.get_ty(*ty);
|
||||||
if let TObj { fields, .. } = ty.as_ref() {
|
if let TObj { fields, .. } = ty.as_ref() {
|
||||||
for (k, v) in map.borrow().iter() {
|
for (k, v) in map.borrow().iter() {
|
||||||
if let Some(ty) = fields.get(k) {
|
let ty = fields.get(k).ok_or_else(|| format!("No such attribute {}", k))?;
|
||||||
if !matches!(self.get_ty(*ty).as_ref(), TFunc { .. }) {
|
if !matches!(self.get_ty(*ty).as_ref(), TFunc { .. }) {
|
||||||
return Err(format!("Cannot access field {} for virtual type", k));
|
return Err(format!("Cannot access field {} for virtual type", k));
|
||||||
}
|
|
||||||
self.unify(*v, *ty)?;
|
|
||||||
} else {
|
|
||||||
return Err(format!("No such attribute {}", k));
|
|
||||||
}
|
}
|
||||||
|
self.unify(*v, *ty)?;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// require annotation...
|
// require annotation...
|
||||||
|
@ -382,11 +394,11 @@ impl Unifier {
|
||||||
if let Some(i) = required.iter().position(|v| v == k) {
|
if let Some(i) = required.iter().position(|v| v == k) {
|
||||||
required.remove(i);
|
required.remove(i);
|
||||||
}
|
}
|
||||||
if let Some(i) = all_names.iter().position(|v| &v.0 == k) {
|
let i = all_names
|
||||||
self.unify(all_names.remove(i).1, *t)?;
|
.iter()
|
||||||
} else {
|
.position(|v| &v.0 == k)
|
||||||
return Err(format!("Unknown keyword argument {}", k));
|
.ok_or_else(|| format!("Unknown keyword argument {}", k))?;
|
||||||
}
|
self.unify(all_names.remove(i).1, *t)?;
|
||||||
}
|
}
|
||||||
if !required.is_empty() {
|
if !required.is_empty() {
|
||||||
return Err("Expected more arguments".to_string());
|
return Err("Expected more arguments".to_string());
|
||||||
|
@ -435,6 +447,7 @@ impl Unifier {
|
||||||
use TypeVarMeta::*;
|
use TypeVarMeta::*;
|
||||||
let ty = self.unification_table.probe_value(ty).clone();
|
let ty = self.unification_table.probe_value(ty).clone();
|
||||||
match ty.as_ref() {
|
match ty.as_ref() {
|
||||||
|
TypeEnum::TRigidVar { id } => var_to_name(*id),
|
||||||
TypeEnum::TVar { id, meta: Generic, .. } => var_to_name(*id),
|
TypeEnum::TVar { id, meta: Generic, .. } => var_to_name(*id),
|
||||||
TypeEnum::TVar { meta: Sequence(map), .. } => {
|
TypeEnum::TVar { meta: Sequence(map), .. } => {
|
||||||
let fields = map
|
let fields = map
|
||||||
|
@ -544,6 +557,7 @@ impl Unifier {
|
||||||
// variables, i.e. things like TRecord, TCall should not occur, and we
|
// variables, i.e. things like TRecord, TCall should not occur, and we
|
||||||
// should be safe to not implement the substitution for those variants.
|
// should be safe to not implement the substitution for those variants.
|
||||||
match &*ty {
|
match &*ty {
|
||||||
|
TypeEnum::TRigidVar { .. } => None,
|
||||||
TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(),
|
TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(),
|
||||||
TypeEnum::TTuple { ty } => {
|
TypeEnum::TTuple { ty } => {
|
||||||
let mut new_ty = Cow::from(ty);
|
let mut new_ty = Cow::from(ty);
|
||||||
|
@ -634,7 +648,7 @@ impl Unifier {
|
||||||
let ty = self.unification_table.probe_value(b).clone();
|
let ty = self.unification_table.probe_value(b).clone();
|
||||||
|
|
||||||
match ty.as_ref() {
|
match ty.as_ref() {
|
||||||
TypeEnum::TVar { meta: Generic, .. } => {}
|
TypeEnum::TRigidVar { .. } | TypeEnum::TVar { meta: Generic, .. } => {}
|
||||||
TypeEnum::TVar { meta: Sequence(map), .. } => {
|
TypeEnum::TVar { meta: Sequence(map), .. } => {
|
||||||
for t in map.borrow().values() {
|
for t in map.borrow().values() {
|
||||||
self.occur_check(a, *t)?;
|
self.occur_check(a, *t)?;
|
||||||
|
|
|
@ -419,22 +419,22 @@ fn test_typevar_range() {
|
||||||
|
|
||||||
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
||||||
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a});
|
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||||
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
|
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
|
||||||
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b});
|
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||||
let b_list = env.unifier.get_fresh_var_with_range(&[b_list]).0;
|
let b_list = env.unifier.get_fresh_var_with_range(&[b_list]).0;
|
||||||
env.unifier.unify(a_list, b_list).unwrap();
|
env.unifier.unify(a_list, b_list).unwrap();
|
||||||
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float});
|
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float });
|
||||||
env.unifier.unify(a_list, float_list).unwrap();
|
env.unifier.unify(a_list, float_list).unwrap();
|
||||||
// previous unifications should not affect a and b
|
// previous unifications should not affect a and b
|
||||||
env.unifier.unify(a, int).unwrap();
|
env.unifier.unify(a, int).unwrap();
|
||||||
|
|
||||||
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||||
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0;
|
||||||
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a});
|
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||||
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b});
|
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||||
env.unifier.unify(a_list, b_list).unwrap();
|
env.unifier.unify(a_list, b_list).unwrap();
|
||||||
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int});
|
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int });
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
env.unifier.unify(a_list, int_list),
|
env.unifier.unify(a_list, int_list),
|
||||||
Err("Cannot unify type variable 19 with TObj due to incompatible value range".into())
|
Err("Cannot unify type variable 19 with TObj due to incompatible value range".into())
|
||||||
|
@ -442,12 +442,34 @@ fn test_typevar_range() {
|
||||||
|
|
||||||
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
|
||||||
let b = env.unifier.get_fresh_var().0;
|
let b = env.unifier.get_fresh_var().0;
|
||||||
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a});
|
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||||
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
|
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
|
||||||
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b});
|
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
|
||||||
env.unifier.unify(a_list, b_list).unwrap();
|
env.unifier.unify(a_list, b_list).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
env.unifier.unify(b, boolean),
|
env.unifier.unify(b, boolean),
|
||||||
Err("Cannot unify type variable 21 with TObj due to incompatible value range".into())
|
Err("Cannot unify type variable 21 with TObj due to incompatible value range".into())
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rigid_var() {
|
||||||
|
let mut env = TestEnvironment::new();
|
||||||
|
let a = env.unifier.get_fresh_rigid_var().0;
|
||||||
|
let b = env.unifier.get_fresh_rigid_var().0;
|
||||||
|
let x = env.unifier.get_fresh_var().0;
|
||||||
|
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a });
|
||||||
|
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x });
|
||||||
|
let int = env.parse("int", &HashMap::new());
|
||||||
|
let list_int = env.parse("List[int]", &HashMap::new());
|
||||||
|
|
||||||
|
assert_eq!(env.unifier.unify(a, b), Err("Cannot unify TRigidVar with TRigidVar".to_string()));
|
||||||
|
env.unifier.unify(list_a, list_x).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
env.unifier.unify(list_x, list_int),
|
||||||
|
Err("Cannot unify TObj with TRigidVar".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
env.unifier.replace_rigid_var(a, int);
|
||||||
|
env.unifier.unify(list_x, list_int).unwrap();
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue