From 1bc95a7ba613d3defe26dc32bd83cc91c46364c5 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 18 Jun 2024 14:14:12 +0800 Subject: [PATCH] Add handling for np.bool_ and np.str_ --- nac3artiq/src/lib.rs | 4 ++++ nac3artiq/src/symbol_resolver.rs | 12 ++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index ebe2636c..973ab9cc 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -102,7 +102,9 @@ pub struct PrimitivePythonId { float: u64, float64: u64, bool: u64, + np_bool_: u64, string: u64, + np_str_: u64, list: u64, ndarray: u64, tuple: u64, @@ -922,7 +924,9 @@ impl Nac3 { uint32: get_attr_id(numpy_mod, "uint32"), uint64: get_attr_id(numpy_mod, "uint64"), bool: get_attr_id(builtins_mod, "bool"), + np_bool_: get_attr_id(numpy_mod, "bool_"), string: get_attr_id(builtins_mod, "str"), + np_str_: get_attr_id(numpy_mod, "str_"), float: get_attr_id(builtins_mod, "float"), float64: get_attr_id(numpy_mod, "float64"), list: get_attr_id(builtins_mod, "list"), diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 0ea37a94..aa7f2a4d 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -311,9 +311,9 @@ impl InnerResolver { Ok(Ok((primitives.uint32, true))) } else if ty_id == self.primitive_ids.uint64 { Ok(Ok((primitives.uint64, true))) - } else if ty_id == self.primitive_ids.bool { + } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { Ok(Ok((primitives.bool, true))) - } else if ty_id == self.primitive_ids.string { + } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { Ok(Ok((primitives.str, true))) } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { Ok(Ok((primitives.float, true))) @@ -873,11 +873,11 @@ impl InnerResolver { let val: u64 = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val)); Ok(Some(ctx.ctx.i64_type().const_int(val, false).into())) - } else if ty_id == self.primitive_ids.bool { + } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { let val: bool = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) - } else if ty_id == self.primitive_ids.string { + } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { let val: String = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); Ok(Some(ctx.ctx.const_string(val.as_bytes(), true).into())) @@ -1124,10 +1124,10 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.uint64 { let val: u64 = obj.extract()?; Ok(SymbolValue::U64(val)) - } else if ty_id == self.primitive_ids.bool { + } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { let val: bool = obj.extract()?; Ok(SymbolValue::Bool(val)) - } else if ty_id == self.primitive_ids.string { + } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { let val: String = obj.extract()?; Ok(SymbolValue::Str(val)) } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 {