From ba8ed6c663e61bd04109996c1faab6793e310ab1 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Wed, 23 Mar 2022 03:38:48 +0800 Subject: [PATCH] nac3artiq: handle recursive types properly --- nac3artiq/src/symbol_resolver.rs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index eb123d75..e177b1ff 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -512,8 +512,8 @@ impl InnerResolver { primitives: &PrimitiveStore, ) -> PyResult> { let ty = self.helper.type_fn.call1(py, (obj,)).unwrap(); - let ty_id: u64 = self.helper.id_fn.call1(py, (ty.clone(),))?.extract(py)?; - if let Some(ty) = self.pyid_to_type.read().get(&ty_id) { + let py_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; + if let Some(ty) = self.pyid_to_type.read().get(&py_obj_id) { return Ok(Ok(*ty)) } let (extracted_ty, inst_check) = match self.get_pyty_obj_type( @@ -620,6 +620,7 @@ impl InnerResolver { Ok(Ok(res)) } (TypeEnum::TObj { params, fields, .. }, false) => { + self.pyid_to_type.write().insert(py_obj_id, extracted_ty); let var_map = params .iter() .map(|(id_var, ty)| { @@ -673,7 +674,13 @@ impl InnerResolver { let extracted_ty = unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty); Ok(Ok(extracted_ty)) }; - instantiate_obj() + let result = instantiate_obj(); + // update/remove the cache according to the result + match result { + Ok(Ok(ty)) => self.pyid_to_type.write().insert(py_obj_id, ty), + _ => self.pyid_to_type.write().remove(&py_obj_id) + }; + result } _ => Ok(Ok(extracted_ty)), } @@ -918,7 +925,9 @@ impl InnerResolver { let values = values?; if let Some(values) = values { let val = ty.const_named_struct(&values); - let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str); + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str) + }); global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) } else { -- 2.44.2