diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index eb54ce9c..ce02fe13 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -443,19 +443,19 @@ impl<'a> Inferencer<'a> { return Err("Async iterator not supported.".to_string()); } new_context.infer_pattern(&generator.target)?; - let elt = new_context.fold_expr(elt)?; let target = new_context.fold_expr(*generator.target)?; let iter = new_context.fold_expr(*generator.iter)?; + let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); + new_context.unify(iter.custom.unwrap(), list, &iter.location)?; let ifs: Vec<_> = generator .ifs .into_iter() .map(|v| new_context.fold_expr(v)) .collect::>()?; + let elt = new_context.fold_expr(elt)?; // iter should be a list of targets... // actually it should be an iterator of targets, but we don't have iter type for now - let list = new_context.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); - new_context.unify(iter.custom.unwrap(), list, &iter.location)?; // if conditions should be bool for v in ifs.iter() { new_context.unify(v.custom.unwrap(), new_context.primitives.bool, &v.location)?; diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index aea6a326..6c9f8dff 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -19,7 +19,13 @@ struct Resolver { } impl SymbolResolver for Resolver { - fn get_symbol_type(&self, _: &mut Unifier, _: &[Arc>], _: &PrimitiveStore, str: StrRef) -> Option { + fn get_symbol_type( + &self, + _: &mut Unifier, + _: &[Arc>], + _: &PrimitiveStore, + str: StrRef, + ) -> Option { self.id_to_type.get(&str).cloned() } @@ -60,6 +66,14 @@ impl TestEnvironment { fields: HashMap::new().into(), params: HashMap::new().into(), }); + if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(int32) { + let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], + ret: int32, + vars: HashMap::new() + }.into())); + fields.borrow_mut().insert("__add__".into(), add_ty); + } let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), fields: HashMap::new().into(), @@ -132,6 +146,14 @@ impl TestEnvironment { fields: HashMap::new().into(), params: HashMap::new().into(), }); + if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(int32) { + let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], + ret: int32, + vars: HashMap::new() + }.into())); + fields.borrow_mut().insert("__add__".into(), add_ty); + } let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), fields: HashMap::new().into(), @@ -347,6 +369,15 @@ impl TestEnvironment { [("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(), &[] ; "lambda test")] +#[test_case(indoc! {" + a = lambda x: x + x + b = lambda x: a(x) + x + a = b + c = b(1) + "}, + [("a", "fn[[x=int32], int32]"), ("b", "fn[[x=int32], int32]"), ("c", "int32")].iter().cloned().collect(), + &[] + ; "lambda test 2")] #[test_case(indoc! {" a = lambda x: x b = lambda x: x @@ -365,11 +396,10 @@ impl TestEnvironment { &[] ; "obj test")] #[test_case(indoc! {" - f = lambda x: True a = [1, 2, 3] - b = [f(x) for x in a if f(x)] + b = [x + x for x in a] "}, - [("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(), + [("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(), &[] ; "listcomp test")] #[test_case(indoc! {" diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 70c99842..0391fa7b 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -1,10 +1,9 @@ -use itertools::{chain, zip, Itertools}; -use std::borrow::Cow; +use itertools::{zip, Itertools}; use std::cell::RefCell; use std::collections::HashMap; -use std::iter::once; use std::rc::Rc; use std::sync::{Arc, Mutex}; +use std::{borrow::Cow, collections::HashSet}; use rustpython_parser::ast::StrRef; @@ -105,6 +104,7 @@ pub struct Unifier { unification_table: UnificationTable>, calls: Vec>, var_id: u32, + unify_cache: HashSet<(Type, Type)>, } impl Default for Unifier { @@ -120,6 +120,7 @@ impl Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new(), + unify_cache: HashSet::new(), top_level: None, } } @@ -148,9 +149,7 @@ impl Unifier { fields .borrow() .iter() - .map(|(name, ty)| { - (*name, self.copy_from(unifier, *ty, type_cache)) - }) + .map(|(name, ty)| (*name, self.copy_from(unifier, *ty, type_cache))) .collect(), ), params: RefCell::new( @@ -192,7 +191,7 @@ impl Unifier { } }; let ty = self.add_ty(ty); - self.unify(placeholder, ty).unwrap(); + self.unify_impl(placeholder, ty, false).unwrap(); type_cache.insert(representative, ty); ty }) @@ -210,6 +209,7 @@ impl Unifier { var_id: lock.1, calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(), top_level: None, + unify_cache: HashSet::new(), } } @@ -380,7 +380,13 @@ impl Unifier { } } - pub fn unify_call(&mut self, call: &Call, b: Type, signature: &FunSignature, required: &[StrRef]) -> Result<(), String> { + pub fn unify_call( + &mut self, + call: &Call, + b: Type, + signature: &FunSignature, + required: &[StrRef], + ) -> Result<(), String> { let Call { posargs, kwargs, ret, fun } = call; let instantiated = self.instantiate_fun(b, &*signature); let r = self.get_ty(instantiated); @@ -394,13 +400,8 @@ impl Unifier { // we check to make sure that all required arguments (those without default // arguments) are provided, and do not provide the same argument twice. let mut required = required.to_vec(); - let mut all_names: Vec<_> = signature - .borrow() - .args - .iter() - .map(|v| (v.name, v.ty)) - .rev() - .collect(); + let mut all_names: Vec<_> = + signature.borrow().args.iter().map(|v| (v.name, v.ty)).rev().collect(); for (i, t) in posargs.iter().enumerate() { if signature.borrow().args.len() <= i { return Err("Too many arguments.".to_string()); @@ -408,7 +409,7 @@ impl Unifier { if !required.is_empty() { required.pop(); } - self.unify(all_names.pop().unwrap().1, *t)?; + self.unify_impl(all_names.pop().unwrap().1, *t, false)?; } for (k, t) in kwargs.iter() { if let Some(i) = required.iter().position(|v| v == k) { @@ -418,17 +419,18 @@ impl Unifier { .iter() .position(|v| &v.0 == k) .ok_or_else(|| format!("Unknown keyword argument {}", k))?; - self.unify(all_names.remove(i).1, *t)?; + self.unify_impl(all_names.remove(i).1, *t, false)?; } if !required.is_empty() { return Err("Expected more arguments".to_string()); } - self.unify(*ret, signature.borrow().ret)?; + self.unify_impl(*ret, signature.borrow().ret, false)?; *fun.borrow_mut() = Some(instantiated); Ok(()) } pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> { + self.unify_cache.clear(); if self.unification_table.unioned(a, b) { Ok(()) } else { @@ -439,6 +441,16 @@ impl Unifier { fn unify_impl(&mut self, a: Type, b: Type, swapped: bool) -> Result<(), String> { use TypeEnum::*; use TypeVarMeta::*; + + if !swapped { + let rep_a = self.unification_table.get_representative(a); + let rep_b = self.unification_table.get_representative(b); + if rep_a == rep_b || self.unify_cache.contains(&(rep_a, rep_b)) { + return Ok(()); + } + self.unify_cache.insert((rep_a, rep_b)); + } + let (ty_a, ty_b) = { ( self.unification_table.probe_value(a).clone(), @@ -447,8 +459,6 @@ impl Unifier { }; match (&*ty_a, &*ty_b) { (TVar { meta: meta1, range: range1, .. }, TVar { meta: meta2, range: range2, .. }) => { - self.occur_check(a, b)?; - self.occur_check(b, a)?; match (meta1, meta2) { (Generic, _) => {} (_, Generic) => { @@ -458,7 +468,7 @@ impl Unifier { let mut fields2 = fields2.borrow_mut(); for (key, value) in fields1.borrow().iter() { if let Some(ty) = fields2.get(key) { - self.unify(*ty, *value)?; + self.unify_impl(*ty, *value, false)?; } else { fields2.insert(*key, *value); } @@ -468,7 +478,7 @@ impl Unifier { let mut map2 = map2.borrow_mut(); for (key, value) in map1.borrow().iter() { if let Some(ty) = map2.get(key) { - self.unify(*ty, *value)?; + self.unify_impl(*ty, *value, false)?; } else { map2.insert(*key, *value); } @@ -503,7 +513,6 @@ impl Unifier { self.set_a_to_b(a, b); } (TVar { meta: Generic, id, range, .. }, _) => { - self.occur_check(a, b)?; // We check for the range of the type variable to see if unification is allowed. // Note that although b may be compatible with a, we may have to constrain type // variables in b to make sure that instantiations of b would always be compatible @@ -512,11 +521,10 @@ impl Unifier { // guaranteed to be compatible with a under all possible instantiations. So we // unify x with b to recursively apply the constrains, and then set a to x. let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); - self.unify(x, b)?; + self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } (TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => { - self.occur_check(a, b)?; let len = ty.len() as i32; for (k, v) in map.borrow().iter() { // handle negative index @@ -527,19 +535,18 @@ impl Unifier { len, k )); } - self.unify(*v, ty[ind as usize])?; + self.unify_impl(*v, ty[ind as usize], false)?; } let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); - self.unify(x, b)?; + self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } (TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => { - self.occur_check(a, b)?; for v in map.borrow().values() { - self.unify(*v, *ty)?; + self.unify_impl(*v, *ty, false)?; } let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); - self.unify(x, b)?; + self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { @@ -551,30 +558,28 @@ impl Unifier { )); } for (x, y) in ty1.iter().zip(ty2.iter()) { - self.unify(*x, *y)?; + self.unify_impl(*x, *y, false)?; } self.set_a_to_b(a, b); } (TList { ty: ty1 }, TList { ty: ty2 }) => { - self.unify(*ty1, *ty2)?; + self.unify_impl(*ty1, *ty2, false)?; self.set_a_to_b(a, b); } (TVar { meta: Record(map), id, range, .. }, TObj { fields, .. }) => { - self.occur_check(a, b)?; for (k, v) in map.borrow().iter() { let ty = fields .borrow() .get(k) .copied() .ok_or_else(|| format!("No such attribute {}", k))?; - self.unify(ty, *v)?; + self.unify_impl(ty, *v, false)?; } let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); - self.unify(x, b)?; + self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } (TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => { - self.occur_check(a, b)?; let ty = self.get_ty(*ty); if let TObj { fields, .. } = ty.as_ref() { for (k, v) in map.borrow().iter() { @@ -586,14 +591,14 @@ impl Unifier { if !matches!(self.get_ty(ty).as_ref(), TFunc { .. }) { return Err(format!("Cannot access field {} for virtual type", k)); } - self.unify(*v, ty)?; + self.unify_impl(*v, ty, false)?; } } else { // require annotation... return Err("Requires type annotation for virtual".to_string()); } let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); - self.unify(x, b)?; + self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } ( @@ -604,12 +609,12 @@ impl Unifier { self.incompatible_types(a, b)?; } for (x, y) in zip(params1.borrow().values(), params2.borrow().values()) { - self.unify(*x, *y)?; + self.unify_impl(*x, *y, false)?; } self.set_a_to_b(a, b); } (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { - self.unify(*ty1, *ty2)?; + self.unify_impl(*ty1, *ty2, false)?; self.set_a_to_b(a, b); } (TCall(calls1), TCall(calls2)) => { @@ -618,7 +623,6 @@ impl Unifier { calls2.borrow_mut().extend_from_slice(&calls1.borrow()); } (TCall(calls), TFunc(signature)) => { - self.occur_check(a, b)?; let required: Vec = signature .borrow() .args @@ -650,9 +654,9 @@ impl Unifier { if x.default_value != y.default_value { return Err("Functions differ in optional parameters value".to_string()); } - self.unify(x.ty, y.ty)?; + self.unify_impl(x.ty, y.ty, false)?; } - self.unify(sign1.ret, sign2.ret)?; + self.unify_impl(sign1.ret, sign2.ret, false)?; self.set_a_to_b(a, b); } _ => { @@ -818,7 +822,6 @@ impl Unifier { mapping: &VarMap, cache: &mut HashMap>, ) -> Option { - use TypeVarMeta::*; let cached = cache.get_mut(&a); if let Some(cached) = cached { if cached.is_none() { @@ -834,7 +837,7 @@ impl Unifier { // should be safe to not implement the substitution for those variants. match &*ty { TypeEnum::TRigidVar { .. } => None, - TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(), + TypeEnum::TVar { id, .. } => mapping.get(id).cloned(), TypeEnum::TTuple { ty } => { let mut new_ty = Cow::from(ty); for (i, t) in ty.iter().enumerate() { @@ -863,7 +866,7 @@ impl Unifier { let need_subst = params.values().any(|v| { let ty = self.unification_table.probe_value(*v); if let TypeEnum::TVar { id, .. } = ty.as_ref() { - mapping.contains_key(&id) + mapping.contains_key(id) } else { false } @@ -882,7 +885,7 @@ impl Unifier { fields: fields.into(), }); if let Some(var) = cache.get(&a).unwrap() { - self.unify(new_ty, *var).unwrap(); + self.unify_impl(new_ty, *var, false).unwrap(); } Some(new_ty) } else { @@ -914,7 +917,10 @@ impl Unifier { None } } - _ => unimplemented!(), + _ => { + println!("{}", ty.get_type_name()); + unreachable!("{} not expected", ty.get_type_name()) + } } } @@ -939,62 +945,6 @@ impl Unifier { map2 } - fn occur_check(&mut self, a: Type, b: Type) -> Result<(), String> { - use TypeVarMeta::*; - if self.unification_table.unioned(a, b) { - return Err("Recursive type is prohibited.".to_owned()); - } - let ty = self.unification_table.probe_value(b).clone(); - - match ty.as_ref() { - TypeEnum::TRigidVar { .. } | TypeEnum::TVar { meta: Generic, .. } => {} - TypeEnum::TVar { meta: Sequence(map), .. } => { - for t in map.borrow().values() { - self.occur_check(a, *t)?; - } - } - TypeEnum::TVar { meta: Record(map), .. } => { - for t in map.borrow().values() { - self.occur_check(a, *t)?; - } - } - TypeEnum::TCall(calls) => { - let call_store = self.calls.clone(); - for t in calls - .borrow() - .iter() - .map(|call| { - let call = call_store[call.0].as_ref(); - chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret)) - }) - .flatten() - { - self.occur_check(a, *t)?; - } - } - TypeEnum::TTuple { ty } => { - for t in ty.iter() { - self.occur_check(a, *t)?; - } - } - TypeEnum::TList { ty } | TypeEnum::TVirtual { ty } => { - self.occur_check(a, *ty)?; - } - TypeEnum::TObj { params: map, .. } => { - for t in map.borrow().values() { - self.occur_check(a, *t)?; - } - } - TypeEnum::TFunc(sig) => { - let FunSignature { args, ret, vars: params } = &*sig.borrow(); - for t in chain!(args.iter().map(|v| &v.ty), params.values(), once(ret)) { - self.occur_check(a, *t)?; - } - } - } - Ok(()) - } - fn get_intersection(&mut self, a: Type, b: Type) -> Result, ()> { use TypeEnum::*; let x = self.get_ty(a); diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index fe56d2d7..064f2f9f 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -290,13 +290,6 @@ fn test_unify( (("v1", "v2"), "No such attribute b") ; "record obj merge" )] -#[test_case(2, - &[ - ("v1", "List[v2]"), - ], - (("v1", "v2"), "Recursive type is prohibited.") - ; "recursive type for lists" -)] /// Test cases for invalid unifications. fn test_invalid_unification( variable_count: u32,