diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 0966d6a0..66380a3c 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -538,13 +538,25 @@ impl InnerResolver { let types = types?; Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) } - (TypeEnum::TObj { params: var_map, fields, .. }, false) => { - self.pyid_to_type.write().insert(ty_id, extracted_ty); + (TypeEnum::TObj { params, fields, .. }, false) => { + let var_map = params + .iter() + .map(|(id_var, ty)| { + if let TypeEnum::TVar { id, range, name, loc, .. } = + &*unifier.get_ty(*ty) + { + assert_eq!(*id, *id_var); + (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) + } else { + unreachable!() + } + }) + .collect::>(); let mut instantiate_obj = || { // loop through non-function fields of the class to get the instantiated value for field in fields.iter() { let name: String = (*field.0).into(); - if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1 .0) { + if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1.0) { continue; } else { let field_data = obj.getattr(&name)?; @@ -560,7 +572,7 @@ impl InnerResolver { } }; let field_ty = - unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0); + unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0); if let Err(e) = unifier.unify(ty, field_ty) { // field type mismatch return Ok(Err(format!( @@ -577,14 +589,10 @@ impl InnerResolver { return Ok(Err("object is not of concrete type".into())); } } + let extracted_ty = unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty); Ok(Ok(extracted_ty)) }; - let result = instantiate_obj(); - // do not cache the type if there are errors - if matches!(result, Err(_) | Ok(Err(_))) { - self.pyid_to_type.write().remove(&ty_id); - } - result + instantiate_obj() } _ => Ok(Ok(extracted_ty)), } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index b9e75e62..fd96e5d9 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -353,6 +353,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>( let (unifier, primitives) = &top_level_ctx.unifiers.read()[task.unifier_index]; (Unifier::from_shared_unifier(unifier), *primitives) }; + unifier.top_level = Some(top_level_ctx.clone()); let mut cache = HashMap::new(); for (a, b) in task.subst.iter() { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index b3a54432..6350f698 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -735,10 +735,11 @@ impl TopLevelComposer { } } + let mut subst_list = Some(Vec::new()); // unification of previously assigned typevar let mut unification_helper = |ty, def| { let target_ty = - get_type_from_type_annotation_kinds(&temp_def_list, unifier, primitives, &def)?; + get_type_from_type_annotation_kinds(&temp_def_list, unifier, primitives, &def, &mut subst_list)?; unifier.unify(ty, target_ty).map_err(|e| e.to_display(unifier).to_string())?; Ok(()) as Result<(), String> }; @@ -747,6 +748,29 @@ impl TopLevelComposer { errors.insert(e); } } + for ty in subst_list.unwrap().into_iter() { + if let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) { + let mut new_fields = HashMap::new(); + let mut need_subst = false; + for (name, (ty, mutable)) in fields.iter() { + let substituted = unifier.subst(*ty, params); + need_subst |= substituted.is_some(); + new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable)); + } + if need_subst { + let new_ty = unifier.add_ty(TypeEnum::TObj { + obj_id: *obj_id, + params: params.clone(), + fields: new_fields, + }); + if let Err(e) = unifier.unify(ty, new_ty) { + errors.insert(e.to_display(unifier).to_string()); + } + } + } else { + unreachable!() + } + } if !errors.is_empty() { return Err(errors.into_iter().sorted().join("\n----------\n")); } @@ -867,6 +891,7 @@ impl TopLevelComposer { unifier, primitives_store, &type_annotation, + &mut None )?; Ok(FuncArg { @@ -934,6 +959,7 @@ impl TopLevelComposer { unifier, primitives_store, &return_ty_annotation, + &mut None )? } else { primitives_store.none @@ -1498,6 +1524,7 @@ impl TopLevelComposer { unifier, primitives_ty, &make_self_type_annotation(type_vars, *object_id), + &mut None )?; if ancestors .iter() @@ -1666,6 +1693,7 @@ impl TopLevelComposer { unifier, primitives_ty, &ty_ann, + &mut None )?; Some((self_ty, type_vars.clone())) } else { diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 9fc1f80c..3f5405db 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -273,6 +273,7 @@ pub fn get_type_from_type_annotation_kinds( unifier: &mut Unifier, primitives: &PrimitiveStore, ann: &TypeAnnotation, + subst_list: &mut Option> ) -> Result { match ann { TypeAnnotation::CustomClass { id: obj_id, params } => { @@ -294,6 +295,7 @@ pub fn get_type_from_type_annotation_kinds( unifier, primitives, x, + subst_list ) }) .collect::, _>>()?; @@ -349,12 +351,16 @@ pub fn get_type_from_type_annotation_kinds( let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); (*name, (subst_ty, *mutability)) })); - - Ok(unifier.add_ty(TypeEnum::TObj { + let need_subst = !subst.is_empty(); + let ty = unifier.add_ty(TypeEnum::TObj { obj_id: *obj_id, fields: tobj_fields, params: subst, - })) + }); + if need_subst { + subst_list.as_mut().map(|wl| wl.push(ty)); + } + Ok(ty) } } else { unreachable!("should be class def here") @@ -367,6 +373,7 @@ pub fn get_type_from_type_annotation_kinds( unifier, primitives, ty.as_ref(), + subst_list )?; Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) } @@ -376,6 +383,7 @@ pub fn get_type_from_type_annotation_kinds( unifier, primitives, ty.as_ref(), + subst_list )?; Ok(unifier.add_ty(TypeEnum::TList { ty })) } @@ -383,7 +391,7 @@ pub fn get_type_from_type_annotation_kinds( let tys = tys .iter() .map(|x| { - get_type_from_type_annotation_kinds(top_level_defs, unifier, primitives, x) + get_type_from_type_annotation_kinds(top_level_defs, unifier, primitives, x, subst_list) }) .collect::, _>>()?; Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys })) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index b13628b8..67b89248 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -520,7 +520,7 @@ impl Unifier { match (&*ty_a, &*ty_b) { ( TVar { fields: fields1, id, name: name1, loc: loc1, .. }, - TVar { fields: fields2, name: name2, loc: loc2, .. }, + TVar { fields: fields2, id: id2, name: name2, loc: loc2, .. }, ) => { let new_fields = match (fields1, fields2) { (None, None) => None, @@ -570,7 +570,7 @@ impl Unifier { self.unification_table.set_value( a, Rc::new(TypeEnum::TVar { - id: *id, + id: name1.map_or(*id2, |_| *id), fields: new_fields, range, name: name1.or(*name2), diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index bf99cafb..37925157 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -85,7 +85,7 @@ fn main() { Default::default(), )?; get_type_from_type_annotation_kinds( - def_list, unifier, primitives, &ty, + def_list, unifier, primitives, &ty, &mut None ) }) .collect::, _>>()?;