From ce3e9bf4fed6f4702154d98673e0898762c5cec2 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 14 Jun 2024 10:22:33 +0800 Subject: [PATCH] nac3artiq: add support string attributes in classes --- nac3artiq/demo/string_attribute_issue337.py | 24 +++++++++++++++++++++ nac3artiq/src/lib.rs | 2 ++ nac3artiq/src/symbol_resolver.rs | 11 ++++++++++ 3 files changed, 37 insertions(+) create mode 100644 nac3artiq/demo/string_attribute_issue337.py diff --git a/nac3artiq/demo/string_attribute_issue337.py b/nac3artiq/demo/string_attribute_issue337.py new file mode 100644 index 0000000..9749462 --- /dev/null +++ b/nac3artiq/demo/string_attribute_issue337.py @@ -0,0 +1,24 @@ +from min_artiq import * +from numpy import int32 + + +@nac3 +class Demo: + core: KernelInvariant[Core] + attr1: KernelInvariant[str] + attr2: KernelInvariant[int32] + + + def __init__(self): + self.core = Core() + self.attr2 = 32 + self.attr1 = "SAMPLE" + + @kernel + def run(self): + print_int32(self.attr2) + self.attr1 + + +if __name__ == "__main__": + Demo().run() diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 04344e2..ebe2636 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -102,6 +102,7 @@ pub struct PrimitivePythonId { float: u64, float64: u64, bool: u64, + string: u64, list: u64, ndarray: u64, tuple: u64, @@ -921,6 +922,7 @@ impl Nac3 { uint32: get_attr_id(numpy_mod, "uint32"), uint64: get_attr_id(numpy_mod, "uint64"), bool: get_attr_id(builtins_mod, "bool"), + string: get_attr_id(builtins_mod, "str"), float: get_attr_id(builtins_mod, "float"), float64: get_attr_id(numpy_mod, "float64"), list: get_attr_id(builtins_mod, "list"), diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 0b9ede9..0ea37a9 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -35,6 +35,7 @@ pub enum PrimitiveValue { U64(u64), F64(f64), Bool(bool), + Str(String), } /// An entry in the [`DeferredEvaluationStore`], containing the deferred types, a [`PyObject`] @@ -154,6 +155,7 @@ impl StaticValue for PythonValue { PrimitiveValue::Bool(val) => { ctx.ctx.i8_type().const_int(u64::from(*val), false).into() } + PrimitiveValue::Str(val) => ctx.ctx.const_string(val.as_bytes(), true).into(), }); } if let Some(global) = ctx.module.get_global(&self.id.to_string()) { @@ -311,6 +313,8 @@ impl InnerResolver { Ok(Ok((primitives.uint64, true))) } else if ty_id == self.primitive_ids.bool { Ok(Ok((primitives.bool, true))) + } else if ty_id == self.primitive_ids.string { + Ok(Ok((primitives.str, true))) } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { Ok(Ok((primitives.float, true))) } else if ty_id == self.primitive_ids.exception { @@ -873,6 +877,10 @@ impl InnerResolver { let val: bool = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) + } else if ty_id == self.primitive_ids.string { + let val: String = obj.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); + Ok(Some(ctx.ctx.const_string(val.as_bytes(), true).into())) } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { let val: f64 = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val)); @@ -1119,6 +1127,9 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract()?; Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.string { + let val: String = obj.extract()?; + Ok(SymbolValue::Str(val)) } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { let val: f64 = obj.extract()?; Ok(SymbolValue::Double(val))