forked from M-Labs/nac3
[artiq] symbol_resolver: Fix support for np.bool_ -> bool decay
This commit is contained in:
parent
8322d457c6
commit
d9c180ed13
nac3artiq
29
nac3artiq/demo/numpy_primitives_decay.py
Normal file
29
nac3artiq/demo/numpy_primitives_decay.py
Normal file
@ -0,0 +1,29 @@
|
||||
from min_artiq import *
|
||||
import numpy
|
||||
from numpy import int32
|
||||
|
||||
|
||||
@nac3
|
||||
class NumpyBoolDecay:
|
||||
core: KernelInvariant[Core]
|
||||
np_true: KernelInvariant[bool]
|
||||
np_false: KernelInvariant[bool]
|
||||
np_int: KernelInvariant[int32]
|
||||
np_float: KernelInvariant[float]
|
||||
np_str: KernelInvariant[str]
|
||||
|
||||
def __init__(self):
|
||||
self.core = Core()
|
||||
self.np_true = numpy.True_
|
||||
self.np_false = numpy.False_
|
||||
self.np_int = numpy.int32(0)
|
||||
self.np_float = numpy.float64(0.0)
|
||||
self.np_str = numpy.str_("")
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
NumpyBoolDecay().run()
|
@ -931,10 +931,13 @@ impl InnerResolver {
|
||||
|_| Ok(Ok(extracted_ty)),
|
||||
)
|
||||
} else if unifier.unioned(extracted_ty, primitives.bool) {
|
||||
obj.extract::<bool>().map_or_else(
|
||||
|_| Ok(Err(format!("{obj} is not in the range of bool"))),
|
||||
|_| Ok(Ok(extracted_ty)),
|
||||
)
|
||||
if let Ok(_) = obj.extract::<bool>() {
|
||||
Ok(Ok(extracted_ty))
|
||||
} else if let Ok(_) = obj.call_method("__bool__", (), None)?.extract::<bool>() {
|
||||
Ok(Ok(extracted_ty))
|
||||
} else {
|
||||
Ok(Err(format!("{obj} is not in the range of bool")))
|
||||
}
|
||||
} else if unifier.unioned(extracted_ty, primitives.float) {
|
||||
obj.extract::<f64>().map_or_else(
|
||||
|_| Ok(Err(format!("{obj} is not in the range of float64"))),
|
||||
@ -974,10 +977,14 @@ impl InnerResolver {
|
||||
let val: u64 = obj.extract().unwrap();
|
||||
self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val));
|
||||
Ok(Some(ctx.ctx.i64_type().const_int(val, false).into()))
|
||||
} else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ {
|
||||
} else if ty_id == self.primitive_ids.bool {
|
||||
let val: bool = obj.extract().unwrap();
|
||||
self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val));
|
||||
Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into()))
|
||||
} else if ty_id == self.primitive_ids.np_bool_ {
|
||||
let val: bool = obj.call_method("__bool__", (), None)?.extract().unwrap();
|
||||
self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val));
|
||||
Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into()))
|
||||
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
|
||||
let val: String = obj.extract().unwrap();
|
||||
self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone()));
|
||||
@ -1413,9 +1420,12 @@ impl InnerResolver {
|
||||
} else if ty_id == self.primitive_ids.uint64 {
|
||||
let val: u64 = obj.extract()?;
|
||||
Ok(SymbolValue::U64(val))
|
||||
} else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ {
|
||||
} else if ty_id == self.primitive_ids.bool {
|
||||
let val: bool = obj.extract()?;
|
||||
Ok(SymbolValue::Bool(val))
|
||||
} else if ty_id == self.primitive_ids.np_bool_ {
|
||||
let val: bool = obj.call_method("__bool__", (), None)?.extract()?;
|
||||
Ok(SymbolValue::Bool(val))
|
||||
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
|
||||
let val: String = obj.extract()?;
|
||||
Ok(SymbolValue::Str(val))
|
||||
|
Loading…
Reference in New Issue
Block a user