From 2470ab7856e92d9970a72205f8ed2d2415e1ba03 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 13 Jun 2024 16:35:26 +0800 Subject: [PATCH] Add support for class objects and strings --- nac3artiq/demo/class_instantiation_test.py | 20 ++++++++++++ nac3artiq/demo/string_attribute_issue339.py | 24 ++++++++++++++ nac3artiq/src/lib.rs | 2 ++ nac3artiq/src/symbol_resolver.rs | 35 +++++++++++++++------ 4 files changed, 72 insertions(+), 9 deletions(-) create mode 100644 nac3artiq/demo/class_instantiation_test.py create mode 100644 nac3artiq/demo/string_attribute_issue339.py diff --git a/nac3artiq/demo/class_instantiation_test.py b/nac3artiq/demo/class_instantiation_test.py new file mode 100644 index 00000000..21648456 --- /dev/null +++ b/nac3artiq/demo/class_instantiation_test.py @@ -0,0 +1,20 @@ +from min_artiq import * + +@nac3 +class Foo: + attr: Kernel[str] + @kernel + def __init__(self): + self.attr = "attr" + +@nac3 +class Bar: + core: KernelInvariant[Core] + def __init__(self): + self.core = Core() + @kernel + def run(self): + a = Foo() + +if __name__ == "__main__": + Bar().run() \ No newline at end of file diff --git a/nac3artiq/demo/string_attribute_issue339.py b/nac3artiq/demo/string_attribute_issue339.py new file mode 100644 index 00000000..e1f09b83 --- /dev/null +++ b/nac3artiq/demo/string_attribute_issue339.py @@ -0,0 +1,24 @@ +from min_artiq import * +from numpy import int32 + + +@nac3 +class Demo: + core: KernelInvariant[Core] + attr1: KernelInvariant[str] + attr2: KernelInvariant[int32] + + + def __init__(self): + self.core = Core() + self.attr2 = 32 + self.attr1 = "SAMPLE" + + @kernel + def run(self): + print_int32(self.attr2) + self.attr1 + + +if __name__ == "__main__": + Demo().run() \ No newline at end of file diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 04344e23..ebe2636c 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -102,6 +102,7 @@ pub struct PrimitivePythonId { float: u64, float64: u64, bool: u64, + string: u64, list: u64, ndarray: u64, tuple: u64, @@ -921,6 +922,7 @@ impl Nac3 { uint32: get_attr_id(numpy_mod, "uint32"), uint64: get_attr_id(numpy_mod, "uint64"), bool: get_attr_id(builtins_mod, "bool"), + string: get_attr_id(builtins_mod, "str"), float: get_attr_id(builtins_mod, "float"), float64: get_attr_id(numpy_mod, "float64"), list: get_attr_id(builtins_mod, "list"), diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 16cb5cae..e9fba5b1 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -35,6 +35,7 @@ pub enum PrimitiveValue { U64(u64), F64(f64), Bool(bool), + Str(String), } /// An entry in the [`DeferredEvaluationStore`], containing the deferred types, a [`PyObject`] @@ -154,6 +155,7 @@ impl StaticValue for PythonValue { PrimitiveValue::Bool(val) => { ctx.ctx.i8_type().const_int(u64::from(*val), false).into() } + PrimitiveValue::Str(val) => ctx.ctx.const_string(val.as_bytes(), true).into(), }); } if let Some(global) = ctx.module.get_global(&self.id.to_string()) { @@ -300,7 +302,11 @@ impl InnerResolver { let ty_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; let ty_ty_id: u64 = self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (pyty,))?,))?.extract(py)?; - + let py_obj_id: u64 = self.helper.id_fn.call1(py, (pyty,))?.extract(py)?; + let get_def_id = || { + self.pyid_to_def.read().get(&ty_id).copied() + .or_else(|| self.pyid_to_def.read().get(&py_obj_id).copied()) + }; if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { Ok(Ok((primitives.int32, true))) } else if ty_id == self.primitive_ids.int64 { @@ -311,6 +317,8 @@ impl InnerResolver { Ok(Ok((primitives.uint64, true))) } else if ty_id == self.primitive_ids.bool { Ok(Ok((primitives.bool, true))) + } else if ty_id == self.primitive_ids.string { + Ok(Ok((primitives.str, true))) } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { Ok(Ok((primitives.float, true))) } else if ty_id == self.primitive_ids.exception { @@ -333,7 +341,7 @@ impl InnerResolver { Ok(Ok((primitives.option, false))) } else if ty_id == self.primitive_ids.none { unreachable!("none cannot be typeid") - } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() { + } else if let Some(def_id) = get_def_id() { let def = defs[def_id.0].read(); let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def else { // only object is supported, functions are not supported @@ -599,12 +607,13 @@ impl InnerResolver { let pyid_to_def = self.pyid_to_def.read(); let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| { defs.iter().find_map(|def| { - if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() { - if object_id == def_id - && constructor.is_some() - && methods.iter().any(|(s, _, _)| s == &"__init__".into()) - { - return *constructor; + if let Some(rear_guard) = def.try_read(){ + if let TopLevelDef::Class { + object_id, methods, constructor, .. + } = &*rear_guard { + if object_id == def_id && constructor.is_some() && methods.iter().any(|(s, _, _)| s == &"__init__".into()) { + return *constructor; + } } } None @@ -624,7 +633,8 @@ impl InnerResolver { self.primitive_ids.generic_alias.0, self.primitive_ids.generic_alias.1, ] - .contains(&self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)?) + .contains(&self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)?) + || self.pyid_to_def.read().contains_key(&py_obj_id) { obj } else { @@ -881,6 +891,10 @@ impl InnerResolver { let val: f64 = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val)); Ok(Some(ctx.ctx.f64_type().const_float(val).into())) + } else if ty_id == self.primitive_ids.string { + let val: String = obj.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); + Ok(Some(ctx.ctx.const_string(val.as_bytes(), true).into())) } else if ty_id == self.primitive_ids.list { let id_str = id.to_string(); @@ -1123,6 +1137,9 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract()?; Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.string { + let val:String = obj.extract()?; + Ok(SymbolValue::Str(val)) } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { let val: f64 = obj.extract()?; Ok(SymbolValue::Double(val))