diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 5255100e..e104ec56 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -531,26 +531,30 @@ impl InnerResolver { } (TypeEnum::TObj { params, fields, .. }, false) => { self.pyid_to_type.write().insert(ty_id, extracted_ty); - let var_map = params - .iter() - .map(|(id_var, ty)| { - if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) { - assert_eq!(*id, *id_var); - (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) + let mut instantiate_obj = || { + let var_map = params + .iter() + .map(|(id_var, ty)| { + if let TypeEnum::TVar { id, range, name, loc, .. } = + &*unifier.get_ty(*ty) + { + assert_eq!(*id, *id_var); + (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) + } else { + unreachable!() + } + }) + .collect::>(); + // loop through non-function fields of the class to get the instantiated value + for field in fields.iter() { + let name: String = (*field.0).into(); + if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1 .0) { + continue; } else { - unreachable!() - } - }) - .collect::>(); - // loop through non-function fields of the class to get the instantiated value - for field in fields.iter() { - let name: String = (*field.0).into(); - if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1 .0) { - continue; - } else { - let field_data = obj.getattr(&name)?; - let ty = - match self.get_obj_type(py, field_data, unifier, defs, primitives)? { + let field_data = obj.getattr(&name)?; + let ty = match self + .get_obj_type(py, field_data, unifier, defs, primitives)? + { Ok(t) => t, Err(e) => { return Ok(Err(format!( @@ -559,24 +563,32 @@ impl InnerResolver { ))) } }; - let field_ty = unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0); - if let Err(e) = unifier.unify(ty, field_ty) { - // field type mismatch - return Ok(Err(format!( - "error when getting type of field `{}` ({})", - name, - e.to_display(unifier).to_string() - ))); + let field_ty = + unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0); + if let Err(e) = unifier.unify(ty, field_ty) { + // field type mismatch + return Ok(Err(format!( + "error when getting type of field `{}` ({})", + name, + e.to_display(unifier).to_string() + ))); + } } } - } - for (_, ty) in var_map.iter() { - // must be concrete type - if !unifier.is_concrete(*ty, &[]) { - return Ok(Err("object is not of concrete type".into())); + for (_, ty) in var_map.iter() { + // must be concrete type + if !unifier.is_concrete(*ty, &[]) { + return Ok(Err("object is not of concrete type".into())); + } } + Ok(Ok(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty))) + }; + let result = instantiate_obj(); + // do not cache the type if there are errors + if matches!(result, Err(_) | Ok(Err(_))) { + self.pyid_to_type.write().remove(&ty_id); } - return Ok(Ok(unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty))); + result } _ => Ok(Ok(extracted_ty)), };