forked from M-Labs/nac3
nac3core/typecheck: supports recursive type inference
This commit is contained in:
parent
fd0b11087e
commit
26076c37ba
@ -443,19 +443,19 @@ impl<'a> Inferencer<'a> {
|
||||
return Err("Async iterator not supported.".to_string());
|
||||
}
|
||||
new_context.infer_pattern(&generator.target)?;
|
||||
let elt = new_context.fold_expr(elt)?;
|
||||
let target = new_context.fold_expr(*generator.target)?;
|
||||
let iter = new_context.fold_expr(*generator.iter)?;
|
||||
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
||||
new_context.unify(iter.custom.unwrap(), list, &iter.location)?;
|
||||
let ifs: Vec<_> = generator
|
||||
.ifs
|
||||
.into_iter()
|
||||
.map(|v| new_context.fold_expr(v))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let elt = new_context.fold_expr(elt)?;
|
||||
// iter should be a list of targets...
|
||||
// actually it should be an iterator of targets, but we don't have iter type for now
|
||||
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
||||
new_context.unify(iter.custom.unwrap(), list, &iter.location)?;
|
||||
// if conditions should be bool
|
||||
for v in ifs.iter() {
|
||||
new_context.unify(v.custom.unwrap(), new_context.primitives.bool, &v.location)?;
|
||||
|
@ -19,7 +19,13 @@ struct Resolver {
|
||||
}
|
||||
|
||||
impl SymbolResolver for Resolver {
|
||||
fn get_symbol_type(&self, _: &mut Unifier, _: &[Arc<RwLock<TopLevelDef>>], _: &PrimitiveStore, str: StrRef) -> Option<Type> {
|
||||
fn get_symbol_type(
|
||||
&self,
|
||||
_: &mut Unifier,
|
||||
_: &[Arc<RwLock<TopLevelDef>>],
|
||||
_: &PrimitiveStore,
|
||||
str: StrRef,
|
||||
) -> Option<Type> {
|
||||
self.id_to_type.get(&str).cloned()
|
||||
}
|
||||
|
||||
@ -60,6 +66,14 @@ impl TestEnvironment {
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(int32) {
|
||||
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
|
||||
ret: int32,
|
||||
vars: HashMap::new()
|
||||
}.into()));
|
||||
fields.borrow_mut().insert("__add__".into(), add_ty);
|
||||
}
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
fields: HashMap::new().into(),
|
||||
@ -132,6 +146,14 @@ impl TestEnvironment {
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new().into(),
|
||||
});
|
||||
if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(int32) {
|
||||
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
|
||||
ret: int32,
|
||||
vars: HashMap::new()
|
||||
}.into()));
|
||||
fields.borrow_mut().insert("__add__".into(), add_ty);
|
||||
}
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
fields: HashMap::new().into(),
|
||||
@ -347,6 +369,15 @@ impl TestEnvironment {
|
||||
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(),
|
||||
&[]
|
||||
; "lambda test")]
|
||||
#[test_case(indoc! {"
|
||||
a = lambda x: x + x
|
||||
b = lambda x: a(x) + x
|
||||
a = b
|
||||
c = b(1)
|
||||
"},
|
||||
[("a", "fn[[x=int32], int32]"), ("b", "fn[[x=int32], int32]"), ("c", "int32")].iter().cloned().collect(),
|
||||
&[]
|
||||
; "lambda test 2")]
|
||||
#[test_case(indoc! {"
|
||||
a = lambda x: x
|
||||
b = lambda x: x
|
||||
@ -365,11 +396,10 @@ impl TestEnvironment {
|
||||
&[]
|
||||
; "obj test")]
|
||||
#[test_case(indoc! {"
|
||||
f = lambda x: True
|
||||
a = [1, 2, 3]
|
||||
b = [f(x) for x in a if f(x)]
|
||||
b = [x + x for x in a]
|
||||
"},
|
||||
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(),
|
||||
[("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(),
|
||||
&[]
|
||||
; "listcomp test")]
|
||||
#[test_case(indoc! {"
|
||||
|
@ -1,10 +1,9 @@
|
||||
use itertools::{chain, zip, Itertools};
|
||||
use std::borrow::Cow;
|
||||
use itertools::{zip, Itertools};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::iter::once;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::{borrow::Cow, collections::HashSet};
|
||||
|
||||
use rustpython_parser::ast::StrRef;
|
||||
|
||||
@ -105,6 +104,7 @@ pub struct Unifier {
|
||||
unification_table: UnificationTable<Rc<TypeEnum>>,
|
||||
calls: Vec<Rc<Call>>,
|
||||
var_id: u32,
|
||||
unify_cache: HashSet<(Type, Type)>,
|
||||
}
|
||||
|
||||
impl Default for Unifier {
|
||||
@ -120,6 +120,7 @@ impl Unifier {
|
||||
unification_table: UnificationTable::new(),
|
||||
var_id: 0,
|
||||
calls: Vec::new(),
|
||||
unify_cache: HashSet::new(),
|
||||
top_level: None,
|
||||
}
|
||||
}
|
||||
@ -148,9 +149,7 @@ impl Unifier {
|
||||
fields
|
||||
.borrow()
|
||||
.iter()
|
||||
.map(|(name, ty)| {
|
||||
(*name, self.copy_from(unifier, *ty, type_cache))
|
||||
})
|
||||
.map(|(name, ty)| (*name, self.copy_from(unifier, *ty, type_cache)))
|
||||
.collect(),
|
||||
),
|
||||
params: RefCell::new(
|
||||
@ -192,7 +191,7 @@ impl Unifier {
|
||||
}
|
||||
};
|
||||
let ty = self.add_ty(ty);
|
||||
self.unify(placeholder, ty).unwrap();
|
||||
self.unify_impl(placeholder, ty, false).unwrap();
|
||||
type_cache.insert(representative, ty);
|
||||
ty
|
||||
})
|
||||
@ -210,6 +209,7 @@ impl Unifier {
|
||||
var_id: lock.1,
|
||||
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
|
||||
top_level: None,
|
||||
unify_cache: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -380,7 +380,13 @@ impl Unifier {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unify_call(&mut self, call: &Call, b: Type, signature: &FunSignature, required: &[StrRef]) -> Result<(), String> {
|
||||
pub fn unify_call(
|
||||
&mut self,
|
||||
call: &Call,
|
||||
b: Type,
|
||||
signature: &FunSignature,
|
||||
required: &[StrRef],
|
||||
) -> Result<(), String> {
|
||||
let Call { posargs, kwargs, ret, fun } = call;
|
||||
let instantiated = self.instantiate_fun(b, &*signature);
|
||||
let r = self.get_ty(instantiated);
|
||||
@ -394,13 +400,8 @@ impl Unifier {
|
||||
// we check to make sure that all required arguments (those without default
|
||||
// arguments) are provided, and do not provide the same argument twice.
|
||||
let mut required = required.to_vec();
|
||||
let mut all_names: Vec<_> = signature
|
||||
.borrow()
|
||||
.args
|
||||
.iter()
|
||||
.map(|v| (v.name, v.ty))
|
||||
.rev()
|
||||
.collect();
|
||||
let mut all_names: Vec<_> =
|
||||
signature.borrow().args.iter().map(|v| (v.name, v.ty)).rev().collect();
|
||||
for (i, t) in posargs.iter().enumerate() {
|
||||
if signature.borrow().args.len() <= i {
|
||||
return Err("Too many arguments.".to_string());
|
||||
@ -408,7 +409,7 @@ impl Unifier {
|
||||
if !required.is_empty() {
|
||||
required.pop();
|
||||
}
|
||||
self.unify(all_names.pop().unwrap().1, *t)?;
|
||||
self.unify_impl(all_names.pop().unwrap().1, *t, false)?;
|
||||
}
|
||||
for (k, t) in kwargs.iter() {
|
||||
if let Some(i) = required.iter().position(|v| v == k) {
|
||||
@ -418,17 +419,18 @@ impl Unifier {
|
||||
.iter()
|
||||
.position(|v| &v.0 == k)
|
||||
.ok_or_else(|| format!("Unknown keyword argument {}", k))?;
|
||||
self.unify(all_names.remove(i).1, *t)?;
|
||||
self.unify_impl(all_names.remove(i).1, *t, false)?;
|
||||
}
|
||||
if !required.is_empty() {
|
||||
return Err("Expected more arguments".to_string());
|
||||
}
|
||||
self.unify(*ret, signature.borrow().ret)?;
|
||||
self.unify_impl(*ret, signature.borrow().ret, false)?;
|
||||
*fun.borrow_mut() = Some(instantiated);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||
self.unify_cache.clear();
|
||||
if self.unification_table.unioned(a, b) {
|
||||
Ok(())
|
||||
} else {
|
||||
@ -439,6 +441,16 @@ impl Unifier {
|
||||
fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> {
|
||||
use TypeEnum::*;
|
||||
use TypeVarMeta::*;
|
||||
|
||||
if !swapped {
|
||||
let rep_a = self.unification_table.get_representative(a);
|
||||
let rep_b = self.unification_table.get_representative(b);
|
||||
if rep_a == rep_b || self.unify_cache.contains(&(rep_a, rep_b)) {
|
||||
return Ok(());
|
||||
}
|
||||
self.unify_cache.insert((rep_a, rep_b));
|
||||
}
|
||||
|
||||
let (ty_a, ty_b) = {
|
||||
(
|
||||
self.unification_table.probe_value(a).clone(),
|
||||
@ -447,8 +459,6 @@ impl Unifier {
|
||||
};
|
||||
match (&*ty_a, &*ty_b) {
|
||||
(TVar { meta: meta1, range: range1, .. }, TVar { meta: meta2, range: range2, .. }) => {
|
||||
self.occur_check(a, b)?;
|
||||
self.occur_check(b, a)?;
|
||||
match (meta1, meta2) {
|
||||
(Generic, _) => {}
|
||||
(_, Generic) => {
|
||||
@ -458,7 +468,7 @@ impl Unifier {
|
||||
let mut fields2 = fields2.borrow_mut();
|
||||
for (key, value) in fields1.borrow().iter() {
|
||||
if let Some(ty) = fields2.get(key) {
|
||||
self.unify(*ty, *value)?;
|
||||
self.unify_impl(*ty, *value, false)?;
|
||||
} else {
|
||||
fields2.insert(*key, *value);
|
||||
}
|
||||
@ -468,7 +478,7 @@ impl Unifier {
|
||||
let mut map2 = map2.borrow_mut();
|
||||
for (key, value) in map1.borrow().iter() {
|
||||
if let Some(ty) = map2.get(key) {
|
||||
self.unify(*ty, *value)?;
|
||||
self.unify_impl(*ty, *value, false)?;
|
||||
} else {
|
||||
map2.insert(*key, *value);
|
||||
}
|
||||
@ -503,7 +513,6 @@ impl Unifier {
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVar { meta: Generic, id, range, .. }, _) => {
|
||||
self.occur_check(a, b)?;
|
||||
// We check for the range of the type variable to see if unification is allowed.
|
||||
// Note that although b may be compatible with a, we may have to constrain type
|
||||
// variables in b to make sure that instantiations of b would always be compatible
|
||||
@ -512,11 +521,10 @@ impl Unifier {
|
||||
// guaranteed to be compatible with a under all possible instantiations. So we
|
||||
// unify x with b to recursively apply the constrains, and then set a to x.
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify(x, b)?;
|
||||
self.unify_impl(x, b, false)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => {
|
||||
self.occur_check(a, b)?;
|
||||
let len = ty.len() as i32;
|
||||
for (k, v) in map.borrow().iter() {
|
||||
// handle negative index
|
||||
@ -527,19 +535,18 @@ impl Unifier {
|
||||
len, k
|
||||
));
|
||||
}
|
||||
self.unify(*v, ty[ind as usize])?;
|
||||
self.unify_impl(*v, ty[ind as usize], false)?;
|
||||
}
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify(x, b)?;
|
||||
self.unify_impl(x, b, false)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => {
|
||||
self.occur_check(a, b)?;
|
||||
for v in map.borrow().values() {
|
||||
self.unify(*v, *ty)?;
|
||||
self.unify_impl(*v, *ty, false)?;
|
||||
}
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify(x, b)?;
|
||||
self.unify_impl(x, b, false)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
|
||||
@ -551,30 +558,28 @@ impl Unifier {
|
||||
));
|
||||
}
|
||||
for (x, y) in ty1.iter().zip(ty2.iter()) {
|
||||
self.unify(*x, *y)?;
|
||||
self.unify_impl(*x, *y, false)?;
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
||||
self.unify(*ty1, *ty2)?;
|
||||
self.unify_impl(*ty1, *ty2, false)?;
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
|
||||
self.occur_check(a, b)?;
|
||||
for (k, v) in map.borrow().iter() {
|
||||
let ty = fields
|
||||
.borrow()
|
||||
.get(k)
|
||||
.copied()
|
||||
.ok_or_else(|| format!("No such attribute {}", k))?;
|
||||
self.unify(ty, *v)?;
|
||||
self.unify_impl(ty, *v, false)?;
|
||||
}
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify(x, b)?;
|
||||
self.unify_impl(x, b, false)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
|
||||
self.occur_check(a, b)?;
|
||||
let ty = self.get_ty(*ty);
|
||||
if let TObj { fields, .. } = ty.as_ref() {
|
||||
for (k, v) in map.borrow().iter() {
|
||||
@ -586,14 +591,14 @@ impl Unifier {
|
||||
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
|
||||
return Err(format!("Cannot access field {} for virtual type", k));
|
||||
}
|
||||
self.unify(*v, ty)?;
|
||||
self.unify_impl(*v, ty, false)?;
|
||||
}
|
||||
} else {
|
||||
// require annotation...
|
||||
return Err("Requires type annotation for virtual".to_string());
|
||||
}
|
||||
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
|
||||
self.unify(x, b)?;
|
||||
self.unify_impl(x, b, false)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(
|
||||
@ -604,12 +609,12 @@ impl Unifier {
|
||||
self.incompatible_types(a, b)?;
|
||||
}
|
||||
for (x, y) in zip(params1.borrow().values(), params2.borrow().values()) {
|
||||
self.unify(*x, *y)?;
|
||||
self.unify_impl(*x, *y, false)?;
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
||||
self.unify(*ty1, *ty2)?;
|
||||
self.unify_impl(*ty1, *ty2, false)?;
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TCall(calls1), TCall(calls2)) => {
|
||||
@ -618,7 +623,6 @@ impl Unifier {
|
||||
calls2.borrow_mut().extend_from_slice(&calls1.borrow());
|
||||
}
|
||||
(TCall(calls), TFunc(signature)) => {
|
||||
self.occur_check(a, b)?;
|
||||
let required: Vec<StrRef> = signature
|
||||
.borrow()
|
||||
.args
|
||||
@ -650,9 +654,9 @@ impl Unifier {
|
||||
if x.default_value != y.default_value {
|
||||
return Err("Functions differ in optional parameters value".to_string());
|
||||
}
|
||||
self.unify(x.ty, y.ty)?;
|
||||
self.unify_impl(x.ty, y.ty, false)?;
|
||||
}
|
||||
self.unify(sign1.ret, sign2.ret)?;
|
||||
self.unify_impl(sign1.ret, sign2.ret, false)?;
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
_ => {
|
||||
@ -818,7 +822,6 @@ impl Unifier {
|
||||
mapping: &VarMap,
|
||||
cache: &mut HashMap<Type, Option<Type>>,
|
||||
) -> Option<Type> {
|
||||
use TypeVarMeta::*;
|
||||
let cached = cache.get_mut(&a);
|
||||
if let Some(cached) = cached {
|
||||
if cached.is_none() {
|
||||
@ -834,7 +837,7 @@ impl Unifier {
|
||||
// should be safe to not implement the substitution for those variants.
|
||||
match &*ty {
|
||||
TypeEnum::TRigidVar { .. } => None,
|
||||
TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(),
|
||||
TypeEnum::TVar { id, .. } => mapping.get(id).cloned(),
|
||||
TypeEnum::TTuple { ty } => {
|
||||
let mut new_ty = Cow::from(ty);
|
||||
for (i, t) in ty.iter().enumerate() {
|
||||
@ -863,7 +866,7 @@ impl Unifier {
|
||||
let need_subst = params.values().any(|v| {
|
||||
let ty = self.unification_table.probe_value(*v);
|
||||
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
|
||||
mapping.contains_key(&id)
|
||||
mapping.contains_key(id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@ -882,7 +885,7 @@ impl Unifier {
|
||||
fields: fields.into(),
|
||||
});
|
||||
if let Some(var) = cache.get(&a).unwrap() {
|
||||
self.unify(new_ty, *var).unwrap();
|
||||
self.unify_impl(new_ty, *var, false).unwrap();
|
||||
}
|
||||
Some(new_ty)
|
||||
} else {
|
||||
@ -914,7 +917,10 @@ impl Unifier {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
_ => {
|
||||
println!("{}", ty.get_type_name());
|
||||
unreachable!("{} not expected", ty.get_type_name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -939,62 +945,6 @@ impl Unifier {
|
||||
map2
|
||||
}
|
||||
|
||||
fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> {
|
||||
use TypeVarMeta::*;
|
||||
if self.unification_table.unioned(a, b) {
|
||||
return Err("Recursive type is prohibited.".to_owned());
|
||||
}
|
||||
let ty = self.unification_table.probe_value(b).clone();
|
||||
|
||||
match ty.as_ref() {
|
||||
TypeEnum::TRigidVar { .. } | TypeEnum::TVar { meta: Generic, .. } => {}
|
||||
TypeEnum::TVar { meta: Sequence(map), .. } => {
|
||||
for t in map.borrow().values() {
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TVar { meta: Record(map), .. } => {
|
||||
for t in map.borrow().values() {
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TCall(calls) => {
|
||||
let call_store = self.calls.clone();
|
||||
for t in calls
|
||||
.borrow()
|
||||
.iter()
|
||||
.map(|call| {
|
||||
let call = call_store[call.0].as_ref();
|
||||
chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))
|
||||
})
|
||||
.flatten()
|
||||
{
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TTuple { ty } => {
|
||||
for t in ty.iter() {
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TList { ty } | TypeEnum::TVirtual { ty } => {
|
||||
self.occur_check(a, *ty)?;
|
||||
}
|
||||
TypeEnum::TObj { params: map, .. } => {
|
||||
for t in map.borrow().values() {
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
TypeEnum::TFunc(sig) => {
|
||||
let FunSignature { args, ret, vars: params } = &*sig.borrow();
|
||||
for t in chain!(args.iter().map(|v| &v.ty), params.values(), once(ret)) {
|
||||
self.occur_check(a, *t)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
|
||||
use TypeEnum::*;
|
||||
let x = self.get_ty(a);
|
||||
|
@ -290,13 +290,6 @@ fn test_unify(
|
||||
(("v1", "v2"), "No such attribute b")
|
||||
; "record obj merge"
|
||||
)]
|
||||
#[test_case(2,
|
||||
&[
|
||||
("v1", "List[v2]"),
|
||||
],
|
||||
(("v1", "v2"), "Recursive type is prohibited.")
|
||||
; "recursive type for lists"
|
||||
)]
|
||||
/// Test cases for invalid unifications.
|
||||
fn test_invalid_unification(
|
||||
variable_count: u32,
|
||||
|
Loading…
Reference in New Issue
Block a user