From d9c180ed13df48fcfcc2b6b148741c09732742bb Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Jan 2025 13:36:51 +0800 Subject: [PATCH] [artiq] symbol_resolver: Fix support for np.bool_ -> bool decay --- nac3artiq/demo/numpy_primitives_decay.py | 29 ++++++++++++++++++++++++ nac3artiq/src/symbol_resolver.rs | 22 +++++++++++++----- 2 files changed, 45 insertions(+), 6 deletions(-) create mode 100644 nac3artiq/demo/numpy_primitives_decay.py diff --git a/nac3artiq/demo/numpy_primitives_decay.py b/nac3artiq/demo/numpy_primitives_decay.py new file mode 100644 index 000000000..957d363f3 --- /dev/null +++ b/nac3artiq/demo/numpy_primitives_decay.py @@ -0,0 +1,29 @@ +from min_artiq import * +import numpy +from numpy import int32 + + +@nac3 +class NumpyBoolDecay: + core: KernelInvariant[Core] + np_true: KernelInvariant[bool] + np_false: KernelInvariant[bool] + np_int: KernelInvariant[int32] + np_float: KernelInvariant[float] + np_str: KernelInvariant[str] + + def __init__(self): + self.core = Core() + self.np_true = numpy.True_ + self.np_false = numpy.False_ + self.np_int = numpy.int32(0) + self.np_float = numpy.float64(0.0) + self.np_str = numpy.str_("") + + @kernel + def run(self): + pass + + +if __name__ == "__main__": + NumpyBoolDecay().run() diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 8e9cd10cb..0f9d57bd0 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -931,10 +931,13 @@ impl InnerResolver { |_| Ok(Ok(extracted_ty)), ) } else if unifier.unioned(extracted_ty, primitives.bool) { - obj.extract::().map_or_else( - |_| Ok(Err(format!("{obj} is not in the range of bool"))), - |_| Ok(Ok(extracted_ty)), - ) + if let Ok(_) = obj.extract::() { + Ok(Ok(extracted_ty)) + } else if let Ok(_) = obj.call_method("__bool__", (), None)?.extract::() { + Ok(Ok(extracted_ty)) + } else { + Ok(Err(format!("{obj} is not in the range of bool"))) + } } else if unifier.unioned(extracted_ty, primitives.float) { obj.extract::().map_or_else( |_| Ok(Err(format!("{obj} is not in the range of float64"))), @@ -974,10 +977,14 @@ impl InnerResolver { let val: u64 = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val)); Ok(Some(ctx.ctx.i64_type().const_int(val, false).into())) - } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { + } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) + } else if ty_id == self.primitive_ids.np_bool_ { + let val: bool = obj.call_method("__bool__", (), None)?.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); + Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { let val: String = obj.extract().unwrap(); self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); @@ -1413,9 +1420,12 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.uint64 { let val: u64 = obj.extract()?; Ok(SymbolValue::U64(val)) - } else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ { + } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract()?; Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.np_bool_ { + let val: bool = obj.call_method("__bool__", (), None)?.extract()?; + Ok(SymbolValue::Bool(val)) } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { let val: String = obj.extract()?; Ok(SymbolValue::Str(val))