forked from M-Labs/nac3
1
0
Fork 0

nac3core/typecheck: supports recursive type inference

This commit is contained in:
pca006132 2021-10-16 15:56:13 +08:00
parent fd0b11087e
commit 26076c37ba
4 changed files with 91 additions and 118 deletions

View File

@ -443,19 +443,19 @@ impl<'a> Inferencer<'a> {
return Err("Async iterator not supported.".to_string());
}
new_context.infer_pattern(&generator.target)?;
let elt = new_context.fold_expr(elt)?;
let target = new_context.fold_expr(*generator.target)?;
let iter = new_context.fold_expr(*generator.iter)?;
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
new_context.unify(iter.custom.unwrap(), list, &iter.location)?;
let ifs: Vec<_> = generator
.ifs
.into_iter()
.map(|v| new_context.fold_expr(v))
.collect::<Result<_, _>>()?;
let elt = new_context.fold_expr(elt)?;
// iter should be a list of targets...
// actually it should be an iterator of targets, but we don't have iter type for now
let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
new_context.unify(iter.custom.unwrap(), list, &iter.location)?;
// if conditions should be bool
for v in ifs.iter() {
new_context.unify(v.custom.unwrap(), new_context.primitives.bool, &v.location)?;

View File

@ -19,7 +19,13 @@ struct Resolver {
}
impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &[Arc<RwLock<TopLevelDef>>], _: &PrimitiveStore, str: StrRef) -> Option<Type> {
fn get_symbol_type(
&self,
_: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>],
_: &PrimitiveStore,
str: StrRef,
) -> Option<Type> {
self.id_to_type.get(&str).cloned()
}
@ -60,6 +66,14 @@ impl TestEnvironment {
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(int32) {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32,
vars: HashMap::new()
}.into()));
fields.borrow_mut().insert("__add__".into(), add_ty);
}
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
@ -132,6 +146,14 @@ impl TestEnvironment {
fields: HashMap::new().into(),
params: HashMap::new().into(),
});
if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(int32) {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32,
vars: HashMap::new()
}.into()));
fields.borrow_mut().insert("__add__".into(), add_ty);
}
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
@ -347,6 +369,15 @@ impl TestEnvironment {
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(),
&[]
; "lambda test")]
#[test_case(indoc! {"
a = lambda x: x + x
b = lambda x: a(x) + x
a = b
c = b(1)
"},
[("a", "fn[[x=int32], int32]"), ("b", "fn[[x=int32], int32]"), ("c", "int32")].iter().cloned().collect(),
&[]
; "lambda test 2")]
#[test_case(indoc! {"
a = lambda x: x
b = lambda x: x
@ -365,11 +396,10 @@ impl TestEnvironment {
&[]
; "obj test")]
#[test_case(indoc! {"
f = lambda x: True
a = [1, 2, 3]
b = [f(x) for x in a if f(x)]
b = [x + x for x in a]
"},
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(),
[("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(),
&[]
; "listcomp test")]
#[test_case(indoc! {"

View File

@ -1,10 +1,9 @@
use itertools::{chain, zip, Itertools};
use std::borrow::Cow;
use itertools::{zip, Itertools};
use std::cell::RefCell;
use std::collections::HashMap;
use std::iter::once;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet};
use rustpython_parser::ast::StrRef;
@ -105,6 +104,7 @@ pub struct Unifier {
unification_table: UnificationTable<Rc<TypeEnum>>,
calls: Vec<Rc<Call>>,
var_id: u32,
unify_cache: HashSet<(Type, Type)>,
}
impl Default for Unifier {
@ -120,6 +120,7 @@ impl Unifier {
unification_table: UnificationTable::new(),
var_id: 0,
calls: Vec::new(),
unify_cache: HashSet::new(),
top_level: None,
}
}
@ -148,9 +149,7 @@ impl Unifier {
fields
.borrow()
.iter()
.map(|(name, ty)| {
(*name, self.copy_from(unifier, *ty, type_cache))
})
.map(|(name, ty)| (*name, self.copy_from(unifier, *ty, type_cache)))
.collect(),
),
params: RefCell::new(
@ -192,7 +191,7 @@ impl Unifier {
}
};
let ty = self.add_ty(ty);
self.unify(placeholder, ty).unwrap();
self.unify_impl(placeholder, ty, false).unwrap();
type_cache.insert(representative, ty);
ty
})
@ -210,6 +209,7 @@ impl Unifier {
var_id: lock.1,
calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(),
top_level: None,
unify_cache: HashSet::new(),
}
}
@ -380,7 +380,13 @@ impl Unifier {
}
}
pub fn unify_call(&mut self, call: &Call, b: Type, signature: &FunSignature, required: &[StrRef]) -> Result<(), String> {
pub fn unify_call(
&mut self,
call: &Call,
b: Type,
signature: &FunSignature,
required: &[StrRef],
) -> Result<(), String> {
let Call { posargs, kwargs, ret, fun } = call;
let instantiated = self.instantiate_fun(b, &*signature);
let r = self.get_ty(instantiated);
@ -394,13 +400,8 @@ impl Unifier {
// we check to make sure that all required arguments (those without default
// arguments) are provided, and do not provide the same argument twice.
let mut required = required.to_vec();
let mut all_names: Vec<_> = signature
.borrow()
.args
.iter()
.map(|v| (v.name, v.ty))
.rev()
.collect();
let mut all_names: Vec<_> =
signature.borrow().args.iter().map(|v| (v.name, v.ty)).rev().collect();
for (i, t) in posargs.iter().enumerate() {
if signature.borrow().args.len() <= i {
return Err("Too many arguments.".to_string());
@ -408,7 +409,7 @@ impl Unifier {
if !required.is_empty() {
required.pop();
}
self.unify(all_names.pop().unwrap().1, *t)?;
self.unify_impl(all_names.pop().unwrap().1, *t, false)?;
}
for (k, t) in kwargs.iter() {
if let Some(i) = required.iter().position(|v| v == k) {
@ -418,17 +419,18 @@ impl Unifier {
.iter()
.position(|v| &v.0 == k)
.ok_or_else(|| format!("Unknown keyword argument {}", k))?;
self.unify(all_names.remove(i).1, *t)?;
self.unify_impl(all_names.remove(i).1, *t, false)?;
}
if !required.is_empty() {
return Err("Expected more arguments".to_string());
}
self.unify(*ret, signature.borrow().ret)?;
self.unify_impl(*ret, signature.borrow().ret, false)?;
*fun.borrow_mut() = Some(instantiated);
Ok(())
}
pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> {
self.unify_cache.clear();
if self.unification_table.unioned(a, b) {
Ok(())
} else {
@ -439,6 +441,16 @@ impl Unifier {
fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> {
use TypeEnum::*;
use TypeVarMeta::*;
if !swapped {
let rep_a = self.unification_table.get_representative(a);
let rep_b = self.unification_table.get_representative(b);
if rep_a == rep_b || self.unify_cache.contains(&(rep_a, rep_b)) {
return Ok(());
}
self.unify_cache.insert((rep_a, rep_b));
}
let (ty_a, ty_b) = {
(
self.unification_table.probe_value(a).clone(),
@ -447,8 +459,6 @@ impl Unifier {
};
match (&*ty_a, &*ty_b) {
(TVar { meta: meta1, range: range1, .. }, TVar { meta: meta2, range: range2, .. }) => {
self.occur_check(a, b)?;
self.occur_check(b, a)?;
match (meta1, meta2) {
(Generic, _) => {}
(_, Generic) => {
@ -458,7 +468,7 @@ impl Unifier {
let mut fields2 = fields2.borrow_mut();
for (key, value) in fields1.borrow().iter() {
if let Some(ty) = fields2.get(key) {
self.unify(*ty, *value)?;
self.unify_impl(*ty, *value, false)?;
} else {
fields2.insert(*key, *value);
}
@ -468,7 +478,7 @@ impl Unifier {
let mut map2 = map2.borrow_mut();
for (key, value) in map1.borrow().iter() {
if let Some(ty) = map2.get(key) {
self.unify(*ty, *value)?;
self.unify_impl(*ty, *value, false)?;
} else {
map2.insert(*key, *value);
}
@ -503,7 +513,6 @@ impl Unifier {
self.set_a_to_b(a, b);
}
(TVar { meta: Generic, id, range, .. }, _) => {
self.occur_check(a, b)?;
// We check for the range of the type variable to see if unification is allowed.
// Note that although b may be compatible with a, we may have to constrain type
// variables in b to make sure that instantiations of b would always be compatible
@ -512,11 +521,10 @@ impl Unifier {
// guaranteed to be compatible with a under all possible instantiations. So we
// unify x with b to recursively apply the constrains, and then set a to x.
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x);
}
(TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => {
self.occur_check(a, b)?;
let len = ty.len() as i32;
for (k, v) in map.borrow().iter() {
// handle negative index
@ -527,19 +535,18 @@ impl Unifier {
len, k
));
}
self.unify(*v, ty[ind as usize])?;
self.unify_impl(*v, ty[ind as usize], false)?;
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x);
}
(TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => {
self.occur_check(a, b)?;
for v in map.borrow().values() {
self.unify(*v, *ty)?;
self.unify_impl(*v, *ty, false)?;
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x);
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
@ -551,30 +558,28 @@ impl Unifier {
));
}
for (x, y) in ty1.iter().zip(ty2.iter()) {
self.unify(*x, *y)?;
self.unify_impl(*x, *y, false)?;
}
self.set_a_to_b(a, b);
}
(TList { ty: ty1 }, TList { ty: ty2 }) => {
self.unify(*ty1, *ty2)?;
self.unify_impl(*ty1, *ty2, false)?;
self.set_a_to_b(a, b);
}
(TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => {
self.occur_check(a, b)?;
for (k, v) in map.borrow().iter() {
let ty = fields
.borrow()
.get(k)
.copied()
.ok_or_else(|| format!("No such attribute {}", k))?;
self.unify(ty, *v)?;
self.unify_impl(ty, *v, false)?;
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x);
}
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
self.occur_check(a, b)?;
let ty = self.get_ty(*ty);
if let TObj { fields, .. } = ty.as_ref() {
for (k, v) in map.borrow().iter() {
@ -586,14 +591,14 @@ impl Unifier {
if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) {
return Err(format!("Cannot access field {} for virtual type", k));
}
self.unify(*v, ty)?;
self.unify_impl(*v, ty, false)?;
}
} else {
// require annotation...
return Err("Requires type annotation for virtual".to_string());
}
let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.unify(x, b)?;
self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x);
}
(
@ -604,12 +609,12 @@ impl Unifier {
self.incompatible_types(a, b)?;
}
for (x, y) in zip(params1.borrow().values(), params2.borrow().values()) {
self.unify(*x, *y)?;
self.unify_impl(*x, *y, false)?;
}
self.set_a_to_b(a, b);
}
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
self.unify(*ty1, *ty2)?;
self.unify_impl(*ty1, *ty2, false)?;
self.set_a_to_b(a, b);
}
(TCall(calls1), TCall(calls2)) => {
@ -618,7 +623,6 @@ impl Unifier {
calls2.borrow_mut().extend_from_slice(&calls1.borrow());
}
(TCall(calls), TFunc(signature)) => {
self.occur_check(a, b)?;
let required: Vec<StrRef> = signature
.borrow()
.args
@ -650,9 +654,9 @@ impl Unifier {
if x.default_value != y.default_value {
return Err("Functions differ in optional parameters value".to_string());
}
self.unify(x.ty, y.ty)?;
self.unify_impl(x.ty, y.ty, false)?;
}
self.unify(sign1.ret, sign2.ret)?;
self.unify_impl(sign1.ret, sign2.ret, false)?;
self.set_a_to_b(a, b);
}
_ => {
@ -818,7 +822,6 @@ impl Unifier {
mapping: &VarMap,
cache: &mut HashMap<Type, Option<Type>>,
) -> Option<Type> {
use TypeVarMeta::*;
let cached = cache.get_mut(&a);
if let Some(cached) = cached {
if cached.is_none() {
@ -834,7 +837,7 @@ impl Unifier {
// should be safe to not implement the substitution for those variants.
match &*ty {
TypeEnum::TRigidVar { .. } => None,
TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(),
TypeEnum::TVar { id, .. } => mapping.get(id).cloned(),
TypeEnum::TTuple { ty } => {
let mut new_ty = Cow::from(ty);
for (i, t) in ty.iter().enumerate() {
@ -863,7 +866,7 @@ impl Unifier {
let need_subst = params.values().any(|v| {
let ty = self.unification_table.probe_value(*v);
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
mapping.contains_key(&id)
mapping.contains_key(id)
} else {
false
}
@ -882,7 +885,7 @@ impl Unifier {
fields: fields.into(),
});
if let Some(var) = cache.get(&a).unwrap() {
self.unify(new_ty, *var).unwrap();
self.unify_impl(new_ty, *var, false).unwrap();
}
Some(new_ty)
} else {
@ -914,7 +917,10 @@ impl Unifier {
None
}
}
_ => unimplemented!(),
_ => {
println!("{}", ty.get_type_name());
unreachable!("{} not expected", ty.get_type_name())
}
}
}
@ -939,62 +945,6 @@ impl Unifier {
map2
}
fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> {
use TypeVarMeta::*;
if self.unification_table.unioned(a, b) {
return Err("Recursive type is prohibited.".to_owned());
}
let ty = self.unification_table.probe_value(b).clone();
match ty.as_ref() {
TypeEnum::TRigidVar { .. } | TypeEnum::TVar { meta: Generic, .. } => {}
TypeEnum::TVar { meta: Sequence(map), .. } => {
for t in map.borrow().values() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TVar { meta: Record(map), .. } => {
for t in map.borrow().values() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TCall(calls) => {
let call_store = self.calls.clone();
for t in calls
.borrow()
.iter()
.map(|call| {
let call = call_store[call.0].as_ref();
chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))
})
.flatten()
{
self.occur_check(a, *t)?;
}
}
TypeEnum::TTuple { ty } => {
for t in ty.iter() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TList { ty } | TypeEnum::TVirtual { ty } => {
self.occur_check(a, *ty)?;
}
TypeEnum::TObj { params: map, .. } => {
for t in map.borrow().values() {
self.occur_check(a, *t)?;
}
}
TypeEnum::TFunc(sig) => {
let FunSignature { args, ret, vars: params } = &*sig.borrow();
for t in chain!(args.iter().map(|v| &v.ty), params.values(), once(ret)) {
self.occur_check(a, *t)?;
}
}
}
Ok(())
}
fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
use TypeEnum::*;
let x = self.get_ty(a);

View File

@ -290,13 +290,6 @@ fn test_unify(
(("v1", "v2"), "No such attribute b")
; "record obj merge"
)]
#[test_case(2,
&[
("v1", "List[v2]"),
],
(("v1", "v2"), "Recursive type is prohibited.")
; "recursive type for lists"
)]
/// Test cases for invalid unifications.
fn test_invalid_unification(
variable_count: u32,