diff --git a/nac3core/src/typecheck/typedef.rs b/nac3core/src/typecheck/typedef.rs index 161eb50..c1f8cff 100644 --- a/nac3core/src/typecheck/typedef.rs +++ b/nac3core/src/typecheck/typedef.rs @@ -2,6 +2,7 @@ use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue}; use generational_arena::{Arena, Index}; use std::cell::RefCell; use std::collections::BTreeMap; +use std::iter::{empty, once, Iterator}; use std::mem::swap; use std::rc::Rc; @@ -50,6 +51,7 @@ impl UnifyKey for Type { type Mapping = BTreeMap; type VarMap = Mapping; +#[derive(Clone)] struct Call { posargs: Vec, kwargs: BTreeMap, @@ -57,6 +59,7 @@ struct Call { fn_id: usize, } +#[derive(Clone)] struct FuncArg { name: String, ty: Type, @@ -320,6 +323,64 @@ impl Unifier { )) } + fn occur_check(&self, a: TypeIndex, b: Type) -> Result<(), String> { + let i_b = self.unification_table.borrow_mut().probe_value(b); + if a == i_b { + return Err("Recursive type detected!".to_owned()); + } + let ty = self.type_arena.borrow().get(i_b.0).unwrap().clone(); + let ty = ty.borrow(); + + match &*ty { + TypeEnum::TVar { .. } => { + // TODO: occur check for bounds... + } + TypeEnum::TSeq { map } | TypeEnum::TObj { params: map, .. } => { + for t in map.values() { + self.occur_check(a, *t)?; + } + } + TypeEnum::TTuple { ty } => { + for t in ty.iter() { + self.occur_check(a, *t)?; + } + } + TypeEnum::TList { ty } | TypeEnum::TVirtual { ty } => { + self.occur_check(a, *ty)?; + } + TypeEnum::TRecord { fields } => { + for t in fields.values() { + self.occur_check(a, *t)?; + } + } + TypeEnum::TCall { calls } => { + for t in calls + .iter() + .map(|call| { + call.posargs + .iter() + .chain(call.kwargs.values()) + .chain(once(&call.ret)) + }) + .flatten() + { + self.occur_check(a, *t)?; + } + } + TypeEnum::TFunc { args, ret, params } => { + for t in args + .iter() + .map(|v| &v.ty) + .chain(params.values()) + .chain(once(ret)) + { + self.occur_check(a, *t)?; + } + } + }; + Ok(()) + } + fn subst(&self, a: Type, mapping: &VarMap) -> Option { let index = self.unification_table.borrow_mut().probe_value(a); let ty_cell = { @@ -406,12 +467,44 @@ impl Unifier { obj_id: *obj_id, params: self .subst_map(¶ms, mapping) - .or_else(|| Some(params.clone())) - .unwrap(), + .unwrap_or_else(|| params.clone()), fields: self .subst_map(&fields, mapping) - .or_else(|| Some(fields.clone())) - .unwrap(), + .unwrap_or_else(|| fields.clone()), + } + .into(), + )); + Some( + self.unification_table + .borrow_mut() + .new_key(TypeIndex(index)), + ) + } else { + None + } + } + TypeEnum::TFunc { args, ret, params } => { + let new_params = self.subst_map(params, mapping); + let new_ret = self.subst(*ret, mapping); + let mut new_args = None; + for (i, t) in args.iter().enumerate() { + if let Some(t1) = self.subst(t.ty, mapping) { + if new_args.is_none() { + new_args = Some(args.clone()); + } + new_args.as_mut().unwrap()[i] = FuncArg { + name: t.name.clone(), + ty: t1, + is_optional: t.is_optional, + }; + } + } + if new_params.is_some() || new_ret.is_some() || new_args.is_some() { + let index = self.type_arena.borrow_mut().insert(Rc::new( + TypeEnum::TFunc { + params: new_params.unwrap_or_else(|| params.clone()), + ret: new_ret.unwrap_or_else(|| *ret), + args: new_args.unwrap_or_else(|| args.clone()), } .into(), ));