diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 152b5975..4c2dd5cd 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -503,7 +503,7 @@ impl InnerResolver { Ok(s) => s, Err(e) => return Ok(Err(e)), }; - return match (&*unifier.get_ty(extracted_ty), inst_check) { + match (&*unifier.get_ty(extracted_ty), inst_check) { // do the instantiation for these three types (TypeEnum::TList { ty }, false) => { let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; @@ -538,22 +538,9 @@ impl InnerResolver { let types = types?; Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) } - (TypeEnum::TObj { params, fields, .. }, false) => { + (TypeEnum::TObj { params: var_map, fields, .. }, false) => { self.pyid_to_type.write().insert(ty_id, extracted_ty); let mut instantiate_obj = || { - let var_map = params - .iter() - .map(|(id_var, ty)| { - if let TypeEnum::TVar { id, range, name, loc, .. } = - &*unifier.get_ty(*ty) - { - assert_eq!(*id, *id_var); - (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) - } else { - unreachable!() - } - }) - .collect::>(); // loop through non-function fields of the class to get the instantiated value for field in fields.iter() { let name: String = (*field.0).into(); @@ -590,7 +577,7 @@ impl InnerResolver { return Ok(Err("object is not of concrete type".into())); } } - Ok(Ok(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty))) + Ok(Ok(extracted_ty)) }; let result = instantiate_obj(); // do not cache the type if there are errors @@ -600,7 +587,7 @@ impl InnerResolver { result } _ => Ok(Ok(extracted_ty)), - }; + } } fn get_obj_value<'ctx, 'a>( @@ -919,6 +906,7 @@ impl SymbolResolver for Resolver { }) .unwrap(), }; + println!("{:?}", result); if let Ok(t) = &result { self.0.id_to_type.write().insert(str, *t); } diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 186453f2..d5df9f2a 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -229,10 +229,10 @@ pub fn parse_type_annotation( Err(format!("Cannot use function name as type at {}", loc)) } } - Err(e) => { + Err(_) => { let ty = resolver .get_symbol_type(unifier, top_level_defs, primitives, *id) - .map_err(|_| format!("Unknown type annotation at {}: {}", loc, e))?; + .map_err(|e| format!("Unknown type annotation at {}: {}", loc, e))?; if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { Ok(ty) } else { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 83cd9cee..5d660533 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1,4 +1,5 @@ use nac3parser::ast::fold::Fold; +use std::rc::Rc; use crate::{ codegen::{expr::get_subst_key, stmt::exn_constructor}, @@ -1192,18 +1193,9 @@ impl TopLevelComposer { unreachable!("must be type var annotation"); } } - let dummy_return_type = unifier.get_dummy_var().0; - type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); - dummy_return_type + get_type_from_type_annotation_kinds(temp_def_list, unifier, primitives, &annotation)? } else { - // if do not have return annotation, return none - // for uniform handling, still use type annoatation - let dummy_return_type = unifier.get_dummy_var().0; - type_var_to_concrete_def.insert( - dummy_return_type, - TypeAnnotation::Primitive(primitives.none), - ); - dummy_return_type + primitives.none } }; @@ -1449,6 +1441,34 @@ impl TopLevelComposer { let primitives_ty = &self.primitives_ty; let definition_ast_list = &self.definition_ast_list; let unifier = &mut self.unifier; + + // first, fix function typevar ids + // they may be changed with our use of placeholders + for (def, _) in definition_ast_list.iter().skip(self.builtin_num) { + if let TopLevelDef::Function { + signature, + var_id, + .. + } = &mut *def.write() { + if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = + unifier.get_ty(*signature).as_ref() { + let new_var_ids = vars.values().map(|v| match &*unifier.get_ty(*v) { + TypeEnum::TVar{id, ..} => *id, + _ => unreachable!(), + }).collect_vec(); + if new_var_ids != *var_id { + let new_signature = FunSignature { + args: args.clone(), + ret: ret.clone(), + vars: new_var_ids.iter().zip(vars.values()).map(|(id, v)| (*id, v.clone())).collect(), + }; + unifier.unification_table.set_value(*signature, Rc::new(TypeEnum::TFunc(new_signature))); + *var_id = new_var_ids; + } + } + } + } + let mut errors = HashSet::new(); let mut analyze = |i, def: &Arc>, ast: &Option| { let class_def = def.read(); @@ -1650,7 +1670,6 @@ impl TopLevelComposer { // if class methods, `vars` also contains all class typevars here let (type_var_subst_comb, no_range_vars) = { let mut no_ranges: Vec = Vec::new(); - let var_ids = vars.keys().copied().collect_vec(); let var_combs = vars .iter() .map(|(_, ty)| { @@ -1669,7 +1688,7 @@ impl TopLevelComposer { .collect_vec(); let mut result: Vec> = Default::default(); for comb in var_combs { - result.push(var_ids.clone().into_iter().zip(comb).collect()); + result.push(insted_vars.clone().into_iter().zip(comb).collect()); } // NOTE: if is empty, means no type var, append a empty subst, ok to do this? if result.is_empty() { diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 9fc1f80c..3d15a760 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -105,8 +105,6 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) } else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { - let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).0; - unifier.unify(var, ty).unwrap(); Ok(TypeAnnotation::TypeVar(ty)) } else { Err(format!("`{}` is not a valid type annotation (at {})", id, expr.location)) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index b58c160b..b13628b8 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -156,9 +156,9 @@ pub type SharedUnifier = Arc, u32, Vec)> #[derive(Clone)] pub struct Unifier { - pub top_level: Option>, - unification_table: UnificationTable>, - pub(super) calls: Vec>, + pub(crate) top_level: Option>, + pub(crate) unification_table: UnificationTable>, + pub(crate) calls: Vec>, var_id: u32, unify_cache: HashSet<(Type, Type)>, snapshot: Option<(usize, u32)>