From 19d183ed845c049663f5dc61cc8852ebf588c312 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 5 Jul 2024 12:54:15 +0800 Subject: [PATCH] artiq: Update to pyo3 v0.21 With the extensive use of as_gil_ref. Will have to refactor those away as well. --- nac3artiq/Cargo.toml | 2 +- nac3artiq/src/codegen.rs | 20 +++--- nac3artiq/src/lib.rs | 110 ++++++++++++++++--------------- nac3artiq/src/symbol_resolver.rs | 60 ++++++++++------- 4 files changed, 108 insertions(+), 84 deletions(-) diff --git a/nac3artiq/Cargo.toml b/nac3artiq/Cargo.toml index 4e0dd08..24fa221 100644 --- a/nac3artiq/Cargo.toml +++ b/nac3artiq/Cargo.toml @@ -10,7 +10,7 @@ crate-type = ["cdylib"] [dependencies] itertools = "0.13" -pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] } +pyo3 = { version = "0.21", features = ["extension-module"] } parking_lot = "0.12" tempfile = "3.10" nac3core = { path = "../nac3core" } diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index d16177f..76f5eb2 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -26,8 +26,8 @@ use nac3core::inkwell::{ }; use pyo3::{ + prelude::*, types::{PyDict, PyList}, - PyObject, PyResult, Python, }; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; @@ -970,7 +970,7 @@ pub fn attributes_writeback( host_attributes: &PyObject, ) -> Result<(), String> { Python::with_gil(|py| -> PyResult> { - let host_attributes: &PyList = host_attributes.downcast(py)?; + let host_attributes = host_attributes.downcast_bound::(py)?; let top_levels = ctx.top_level.definitions.read(); let globals = inner_resolver.global_value_ids.read(); let int32 = ctx.ctx.i32_type(); @@ -978,10 +978,10 @@ pub fn attributes_writeback( let mut values = Vec::new(); let mut scratch_buffer = Vec::new(); for val in (*globals).values() { - let val = val.as_ref(py); + let val = val.bind_borrowed(py); let ty = inner_resolver.get_obj_type( py, - val, + val.as_gil_ref(), &mut ctx.unifier, &top_levels, &ctx.primitives, @@ -997,7 +997,9 @@ pub fn attributes_writeback( // we only care about primitive attributes // for non-primitive attributes, they should be in another global let mut attributes = Vec::new(); - let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); + let obj = inner_resolver + .get_obj_value(py, val.as_gil_ref(), ctx, generator, ty)? + .unwrap(); for (name, (field_ty, is_mutable)) in fields { if !is_mutable { continue; @@ -1016,7 +1018,7 @@ pub fn attributes_writeback( } } if !attributes.is_empty() { - let pydict = PyDict::new(py); + let pydict = PyDict::new_bound(py); pydict.set_item("obj", val)?; pydict.set_item("fields", attributes)?; host_attributes.append(pydict)?; @@ -1026,12 +1028,14 @@ pub fn attributes_writeback( let elem_ty = iter_type_vars(params).next().unwrap().ty; if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() { - let pydict = PyDict::new(py); + let pydict = PyDict::new_bound(py); pydict.set_item("obj", val)?; host_attributes.append(pydict)?; values.push(( ty, - inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(), + inner_resolver + .get_obj_value(py, val.as_gil_ref(), ctx, generator, ty)? + .unwrap(), )); } } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 9675efb..7750b91 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -40,9 +40,11 @@ use nac3core::nac3parser::{ }; use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap}; -use pyo3::create_exception; -use pyo3::prelude::*; -use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet}; +use pyo3::{ + create_exception, exceptions, + prelude::*, + types::{PyBytes, PyDict, PySet}, +}; use parking_lot::{Mutex, RwLock}; @@ -174,7 +176,7 @@ impl Nac3 { // Drop unregistered (i.e. host-only) base classes. bases.retain(|base| { Python::with_gil(|py| -> PyResult { - let id_fn = PyModule::import(py, "builtins")?.getattr("id")?; + let id_fn = PyModule::import_bound(py, "builtins")?.getattr("id")?; match &base.node { ExprKind::Name { id, .. } => { if *id == "Exception".into() { @@ -361,10 +363,10 @@ impl Nac3 { fn compile_method( &self, - obj: &PyAny, + obj: &Bound, method_name: &str, args: Vec<&PyAny>, - embedding_map: &PyAny, + embedding_map: &Bound, py: Python, link_fn: &dyn Fn(&Module) -> PyResult, ) -> PyResult { @@ -376,8 +378,8 @@ impl Nac3 { size_t, ); - let builtins = PyModule::import(py, "builtins")?; - let typings = PyModule::import(py, "typing")?; + let builtins = PyModule::import_bound(py, "builtins")?; + let typings = PyModule::import_bound(py, "typing")?; let id_fn = builtins.getattr("id")?; let issubclass = builtins.getattr("issubclass")?; let exn_class = builtins.getattr("Exception")?; @@ -421,7 +423,7 @@ impl Nac3 { let class_obj; if let StmtKind::ClassDef { name, .. } = &stmt.node { let class = py_module.getattr(name.to_string().as_str()).unwrap(); - if issubclass.call1((class, exn_class)).unwrap().extract().unwrap() + if issubclass.call1((class, exn_class.as_gil_ref())).unwrap().extract().unwrap() && class.getattr("artiq_builtin").is_err() { class_obj = Some(class); @@ -513,15 +515,15 @@ impl Nac3 { } } - let id_fun = PyModule::import(py, "builtins")?.getattr("id")?; + let id_fun = PyModule::import_bound(py, "builtins")?.getattr("id")?; let mut name_to_pyid: HashMap = HashMap::new(); - let module = PyModule::new(py, "tmp")?; + let module = PyModule::new_bound(py, "tmp")?; module.add("base", obj)?; name_to_pyid.insert("base".into(), id_fun.call1((obj,))?.extract()?); let mut arg_names = vec![]; for (i, arg) in args.into_iter().enumerate() { let name = format!("tmp{i}"); - module.add(&name, arg)?; + module.add(&*name, arg)?; name_to_pyid.insert(name.clone().into(), id_fun.call1((arg,))?.extract()?); arg_names.push(name); } @@ -900,7 +902,7 @@ fn add_exceptions( #[pymethods] impl Nac3 { #[new] - fn new(isa: &str, artiq_builtins: &PyDict, py: Python) -> PyResult { + fn new(isa: &str, artiq_builtins: &Bound, py: Python) -> PyResult { let isa = match isa { "host" => Isa::Host, "rv32g" => Isa::RiscV32G, @@ -964,43 +966,45 @@ impl Nac3 { ), ]; - let builtins_mod = PyModule::import(py, "builtins").unwrap(); + let builtins_mod = PyModule::import_bound(py, "builtins").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap(); - let numpy_mod = PyModule::import(py, "numpy").unwrap(); - let typing_mod = PyModule::import(py, "typing").unwrap(); - let types_mod = PyModule::import(py, "types").unwrap(); + let numpy_mod = PyModule::import_bound(py, "numpy").unwrap(); + let typing_mod = PyModule::import_bound(py, "typing").unwrap(); + let types_mod = PyModule::import_bound(py, "types").unwrap(); - let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap(); + let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(|id| id.extract()).unwrap(); let get_attr_id = |obj: &PyModule, attr| { id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap() }; let primitive_ids = PrimitivePythonId { - virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()), + virtual_id: get_id( + artiq_builtins.get_item("virtual").ok().flatten().unwrap().as_gil_ref(), + ), generic_alias: ( - get_attr_id(typing_mod, "_GenericAlias"), - get_attr_id(types_mod, "GenericAlias"), + get_attr_id(typing_mod.as_gil_ref(), "_GenericAlias"), + get_attr_id(types_mod.as_gil_ref(), "GenericAlias"), ), - none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()), - typevar: get_attr_id(typing_mod, "TypeVar"), + none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap().as_gil_ref()), + typevar: get_attr_id(typing_mod.as_gil_ref(), "TypeVar"), const_generic_marker: get_id( - artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(), + artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap().as_gil_ref(), ), - int: get_attr_id(builtins_mod, "int"), - int32: get_attr_id(numpy_mod, "int32"), - int64: get_attr_id(numpy_mod, "int64"), - uint32: get_attr_id(numpy_mod, "uint32"), - uint64: get_attr_id(numpy_mod, "uint64"), - bool: get_attr_id(builtins_mod, "bool"), - np_bool_: get_attr_id(numpy_mod, "bool_"), - string: get_attr_id(builtins_mod, "str"), - np_str_: get_attr_id(numpy_mod, "str_"), - float: get_attr_id(builtins_mod, "float"), - float64: get_attr_id(numpy_mod, "float64"), - list: get_attr_id(builtins_mod, "list"), - ndarray: get_attr_id(numpy_mod, "ndarray"), - tuple: get_attr_id(builtins_mod, "tuple"), - exception: get_attr_id(builtins_mod, "Exception"), - option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()), + int: get_attr_id(builtins_mod.as_gil_ref(), "int"), + int32: get_attr_id(numpy_mod.as_gil_ref(), "int32"), + int64: get_attr_id(numpy_mod.as_gil_ref(), "int64"), + uint32: get_attr_id(numpy_mod.as_gil_ref(), "uint32"), + uint64: get_attr_id(numpy_mod.as_gil_ref(), "uint64"), + bool: get_attr_id(builtins_mod.as_gil_ref(), "bool"), + np_bool_: get_attr_id(numpy_mod.as_gil_ref(), "bool_"), + string: get_attr_id(builtins_mod.as_gil_ref(), "str"), + np_str_: get_attr_id(numpy_mod.as_gil_ref(), "str_"), + float: get_attr_id(builtins_mod.as_gil_ref(), "float"), + float64: get_attr_id(numpy_mod.as_gil_ref(), "float64"), + list: get_attr_id(builtins_mod.as_gil_ref(), "list"), + ndarray: get_attr_id(numpy_mod.as_gil_ref(), "ndarray"), + tuple: get_attr_id(builtins_mod.as_gil_ref(), "tuple"), + exception: get_attr_id(builtins_mod.as_gil_ref(), "Exception"), + option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap().as_gil_ref()), }; let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); @@ -1025,23 +1029,23 @@ impl Nac3 { }) } - fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> { + fn analyze(&mut self, functions: &Bound, classes: &Bound) -> PyResult<()> { let (modules, class_ids) = Python::with_gil(|py| -> PyResult<(HashMap, HashSet)> { let mut modules: HashMap = HashMap::new(); let mut class_ids: HashSet = HashSet::new(); - let id_fn = PyModule::import(py, "builtins")?.getattr("id")?; - let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?; + let id_fn = PyModule::import_bound(py, "builtins")?.getattr("id")?; + let getmodule_fn = PyModule::import_bound(py, "inspect")?.getattr("getmodule")?; for function in functions { let module = getmodule_fn.call1((function,))?.extract()?; modules.insert(id_fn.call1((&module,))?.extract()?, module); } for class in classes { - let module = getmodule_fn.call1((class,))?.extract()?; + let module = getmodule_fn.call1((class.as_gil_ref(),))?.extract()?; modules.insert(id_fn.call1((&module,))?.extract()?, module); - class_ids.insert(id_fn.call1((class,))?.extract()?); + class_ids.insert(id_fn.call1((class.as_gil_ref(),))?.extract()?); } Ok((modules, class_ids)) })?; @@ -1054,11 +1058,11 @@ impl Nac3 { fn compile_method_to_file( &mut self, - obj: &PyAny, + obj: &Bound, method_name: &str, args: Vec<&PyAny>, filename: &str, - embedding_map: &PyAny, + embedding_map: &Bound, py: Python, ) -> PyResult<()> { let target_machine = self.get_llvm_target_machine(); @@ -1100,10 +1104,10 @@ impl Nac3 { fn compile_method_to_mem( &mut self, - obj: &PyAny, + obj: &Bound, method_name: &str, args: Vec<&PyAny>, - embedding_map: &PyAny, + embedding_map: &Bound, py: Python, ) -> PyResult { let target_machine = self.get_llvm_target_machine(); @@ -1122,7 +1126,7 @@ impl Nac3 { working_directory.join("module.o").to_string_lossy().to_string(), )?; - Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into()) + Ok(PyBytes::new_bound(py, &fs::read(filename).unwrap()).into()) }; self.compile_method(obj, method_name, args, embedding_map, py, &link_fn) @@ -1132,7 +1136,7 @@ impl Nac3 { .write_to_memory_buffer(module, FileType::Object) .expect("couldn't write module to object file buffer"); if let Ok(dyn_lib) = Linker::ld(object_mem.as_slice()) { - Ok(PyBytes::new(py, &dyn_lib).into()) + Ok(PyBytes::new_bound(py, &dyn_lib).into()) } else { Err(CompileError::new_err("linker failed to process object file")) } @@ -1149,14 +1153,14 @@ extern "C" { } #[pymodule] -fn nac3artiq(py: Python, m: &PyModule) -> PyResult<()> { +fn nac3artiq(py: Python, m: &Bound) -> PyResult<()> { #[cfg(feature = "init-llvm-profile")] unsafe { __llvm_profile_initialize(); } Target::initialize_all(&InitializationConfig::default()); - m.add("CompileError", py.get_type::())?; + m.add("CompileError", py.get_type_bound::())?; m.add_class::()?; Ok(()) } diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index d3381d1..721e913 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -25,8 +25,8 @@ use nac3core::{ }; use parking_lot::RwLock; use pyo3::{ + prelude::*, types::{PyDict, PyTuple}, - PyAny, PyObject, PyResult, Python, }; use std::{ collections::{HashMap, HashSet}, @@ -173,7 +173,7 @@ impl StaticValue for PythonValue { Python::with_gil(|py| -> PyResult> { self.resolver - .get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty) + .get_obj_value(py, self.value.bind(py).as_gil_ref(), ctx, generator, expected_ty) .map(Option::unwrap) }) .map_err(|e| e.to_string()) @@ -461,22 +461,27 @@ impl InnerResolver { { let origin = self.helper.origin_ty_fn.call1(py, (pyty,))?; let args = self.helper.args_ty_fn.call1(py, (pyty,))?; - let args: &PyTuple = args.downcast(py)?; - let origin_ty = - match self.get_pyty_obj_type(py, origin.as_ref(py), unifier, defs, primitives)? { - Ok((ty, false)) => ty, - Ok((_, true)) => { - return Ok(Err("instantiated type does not take type parameters".into())) - } - Err(err) => return Ok(Err(err)), - }; + let args = args.downcast_bound::(py)?; + let origin_ty = match self.get_pyty_obj_type( + py, + origin.bind(py).as_gil_ref(), + unifier, + defs, + primitives, + )? { + Ok((ty, false)) => ty, + Ok((_, true)) => { + return Ok(Err("instantiated type does not take type parameters".into())) + } + Err(err) => return Ok(Err(err)), + }; match &*unifier.get_ty(origin_ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { if args.len() == 1 { let ty = match self.get_pyty_obj_type( py, - args.get_item(0)?, + args.get_item(0)?.as_gil_ref(), unifier, defs, primitives, @@ -522,9 +527,15 @@ impl InnerResolver { // npt.NDArray[T] == np.ndarray[Any, np.dtype[T]] let ndarray_dtype_pyty = self.helper.args_ty_fn.call1(py, (args.get_item(1)?,))?; - let dtype = ndarray_dtype_pyty.downcast::(py)?.get_item(0)?; + let dtype = ndarray_dtype_pyty.downcast_bound::(py)?.get_item(0)?; - let ty = match self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)? { + let ty = match self.get_pyty_obj_type( + py, + dtype.as_gil_ref(), + unifier, + defs, + primitives, + )? { Ok(ty) => ty, Err(err) => return Ok(Err(err)), }; @@ -540,7 +551,7 @@ impl InnerResolver { TypeEnum::TTuple { .. } => { let args = match args .iter() - .map(|x| self.get_pyty_obj_type(py, x, unifier, defs, primitives)) + .map(|x| self.get_pyty_obj_type(py, x.as_gil_ref(), unifier, defs, primitives)) .collect::, _>>()? .into_iter() .collect::, _>>() { @@ -573,7 +584,7 @@ impl InnerResolver { } let args = match args .iter() - .map(|x| self.get_pyty_obj_type(py, x, unifier, defs, primitives)) + .map(|x| self.get_pyty_obj_type(py, x.as_gil_ref(), unifier, defs, primitives)) .collect::, _>>()? .into_iter() .collect::, _>>() { @@ -600,7 +611,7 @@ impl InnerResolver { if args.len() == 1 { let ty = match self.get_pyty_obj_type( py, - args.get_item(0)?, + args.get_item(0)?.as_gil_ref(), unifier, defs, primitives, @@ -631,8 +642,7 @@ impl InnerResolver { false, ))) } else { - let str_fn = - pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap(); + let str_fn = PyModule::import_bound(py, "builtins").unwrap().getattr("repr").unwrap(); let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap(); Ok(Err(format!("{str_repr} is not registered with NAC3 (@nac3 decorator missing?)"))) } @@ -688,7 +698,7 @@ impl InnerResolver { { obj } else { - ty.as_ref(py) + ty.bind(py).as_gil_ref() } }, unifier, @@ -1549,10 +1559,16 @@ impl SymbolResolver for Resolver { let store = self.0.deferred_eval_store.store.read(); Python::with_gil(|py| -> PyResult> { for (variables, constraints, name) in store.iter() { - let constraints: &PyAny = constraints.as_ref(py); + let constraints = constraints.bind(py); for (i, var) in variables.iter().enumerate() { if let Ok(constr) = constraints.get_item(i) { - match self.0.get_pyty_obj_type(py, constr, unifier, defs, primitives)? { + match self.0.get_pyty_obj_type( + py, + constr.as_gil_ref(), + unifier, + defs, + primitives, + )? { Ok((ty, _)) => { if !unifier.is_concrete(ty, &[]) { return Ok(Err(format!(