From d7e51b2b407cf221fa08cfd06fdc63c10a4ef2dd Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 14 Jun 2024 15:19:00 +0800 Subject: [PATCH 1/2] nac3artiq: add support for class objects without __init__ method --- nac3artiq/demo/class_instantiation_test.py | 26 ++++++++++++++++++++++ nac3artiq/src/symbol_resolver.rs | 24 +++++++++++++++----- 2 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 nac3artiq/demo/class_instantiation_test.py diff --git a/nac3artiq/demo/class_instantiation_test.py b/nac3artiq/demo/class_instantiation_test.py new file mode 100644 index 00000000..7d8608a8 --- /dev/null +++ b/nac3artiq/demo/class_instantiation_test.py @@ -0,0 +1,26 @@ +from min_artiq import * +from numpy import int32, ceil + + +@nac3 +class Foo: + attr: KernelInvariant[int32] + +@nac3 +class Bar: + core: KernelInvariant[Core] + attr2: KernelInvariant[int32] + + def __init__(self): + self.core = Core() + self.attr2 = 4 + + @kernel + def run(self): + self.core.reset() + a: Foo = Foo() + + + +if __name__ == "__main__": + Bar().run() \ No newline at end of file diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 0b9ede90..8d271dac 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -300,6 +300,14 @@ impl InnerResolver { let ty_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; let ty_ty_id: u64 = self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (pyty,))?,))?.extract(py)?; + let py_obj_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; + let get_def_id = || { + self.pyid_to_def + .read() + .get(&ty_id) + .copied() + .or_else(|| self.pyid_to_def.read().get(&py_obj_id).copied()) + }; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { Ok(Ok((primitives.int32, true))) @@ -333,7 +341,7 @@ impl InnerResolver { Ok(Ok((primitives.option, false))) } else if ty_id == self.primitive_ids.none { unreachable!("none cannot be typeid") - } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() { + } else if let Some(def_id) = get_def_id() { let def = defs[def_id.0].read(); let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def else { // only object is supported, functions are not supported @@ -599,12 +607,15 @@ impl InnerResolver { let pyid_to_def = self.pyid_to_def.read(); let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| { defs.iter().find_map(|def| { - if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() { - if object_id == def_id - && constructor.is_some() - && methods.iter().any(|(s, _, _)| s == &"__init__".into()) + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*rear_guard { - return *constructor; + if object_id == def_id + && constructor.is_some() + && methods.iter().any(|(s, _, _)| s == &"__init__".into()) + { + return *constructor; + } } } None @@ -625,6 +636,7 @@ impl InnerResolver { self.primitive_ids.generic_alias.1, ] .contains(&self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)?) + || self.pyid_to_def.read().contains_key(&py_obj_id) { obj } else { -- 2.44.2 From 919e247bd814c16a984b03628a38ffce074be76a Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 14 Jun 2024 15:42:10 +0800 Subject: [PATCH 2/2] nac3artiq/symbol_resolver: remove redundant declaration --- nac3artiq/src/symbol_resolver.rs | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 8d271dac..389ea406 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -300,14 +300,6 @@ impl InnerResolver { let ty_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; let ty_ty_id: u64 = self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (pyty,))?,))?.extract(py)?; - let py_obj_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; - let get_def_id = || { - self.pyid_to_def - .read() - .get(&ty_id) - .copied() - .or_else(|| self.pyid_to_def.read().get(&py_obj_id).copied()) - }; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { Ok(Ok((primitives.int32, true))) @@ -341,7 +333,14 @@ impl InnerResolver { Ok(Ok((primitives.option, false))) } else if ty_id == self.primitive_ids.none { unreachable!("none cannot be typeid") - } else if let Some(def_id) = get_def_id() { + } else if let Some(def_id) = { + let py_obj_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; + self.pyid_to_def + .read() + .get(&ty_id) + .copied() + .or_else(|| self.pyid_to_def.read().get(&py_obj_id).copied()) + } { let def = defs[def_id.0].read(); let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def else { // only object is supported, functions are not supported -- 2.44.2