diff --git a/nac3artiq/demo/class_instantiation_test.py b/nac3artiq/demo/class_instantiation_test.py new file mode 100644 index 0000000..7d8608a --- /dev/null +++ b/nac3artiq/demo/class_instantiation_test.py @@ -0,0 +1,26 @@ +from min_artiq import * +from numpy import int32, ceil + + +@nac3 +class Foo: + attr: KernelInvariant[int32] + +@nac3 +class Bar: + core: KernelInvariant[Core] + attr2: KernelInvariant[int32] + + def __init__(self): + self.core = Core() + self.attr2 = 4 + + @kernel + def run(self): + self.core.reset() + a: Foo = Foo() + + + +if __name__ == "__main__": + Bar().run() \ No newline at end of file diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 0b9ede9..8d271da 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -300,6 +300,14 @@ impl InnerResolver { let ty_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; let ty_ty_id: u64 = self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (pyty,))?,))?.extract(py)?; + let py_obj_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; + let get_def_id = || { + self.pyid_to_def + .read() + .get(&ty_id) + .copied() + .or_else(|| self.pyid_to_def.read().get(&py_obj_id).copied()) + }; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { Ok(Ok((primitives.int32, true))) @@ -333,7 +341,7 @@ impl InnerResolver { Ok(Ok((primitives.option, false))) } else if ty_id == self.primitive_ids.none { unreachable!("none cannot be typeid") - } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() { + } else if let Some(def_id) = get_def_id() { let def = defs[def_id.0].read(); let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def else { // only object is supported, functions are not supported @@ -599,12 +607,15 @@ impl InnerResolver { let pyid_to_def = self.pyid_to_def.read(); let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| { defs.iter().find_map(|def| { - if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() { - if object_id == def_id - && constructor.is_some() - && methods.iter().any(|(s, _, _)| s == &"__init__".into()) + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*rear_guard { - return *constructor; + if object_id == def_id + && constructor.is_some() + && methods.iter().any(|(s, _, _)| s == &"__init__".into()) + { + return *constructor; + } } } None @@ -625,6 +636,7 @@ impl InnerResolver { self.primitive_ids.generic_alias.1, ] .contains(&self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)?) + || self.pyid_to_def.read().contains_key(&py_obj_id) { obj } else {