artiq: Update to pyo3 v0.21

With the extensive use of as_gil_ref. Will have to refactor those away
as well.
This commit is contained in:
David Mak 2024-07-05 12:54:15 +08:00
parent 0fc26df29e
commit e6926656a1
4 changed files with 108 additions and 84 deletions

View File

@ -10,7 +10,7 @@ crate-type = ["cdylib"]
[dependencies]
itertools = "0.13"
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
pyo3 = { version = "0.21", features = ["extension-module"] }
parking_lot = "0.12"
tempfile = "3.10"
nac3parser = { path = "../nac3parser" }

View File

@ -26,8 +26,8 @@ use inkwell::{
};
use pyo3::{
prelude::*,
types::{PyDict, PyList},
PyObject, PyResult, Python,
};
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
@ -730,7 +730,7 @@ pub fn attributes_writeback(
host_attributes: &PyObject,
) -> Result<(), String> {
Python::with_gil(|py| -> PyResult<Result<(), String>> {
let host_attributes: &PyList = host_attributes.downcast(py)?;
let host_attributes = host_attributes.downcast_bound::<PyList>(py)?;
let top_levels = ctx.top_level.definitions.read();
let globals = inner_resolver.global_value_ids.read();
let int32 = ctx.ctx.i32_type();
@ -738,10 +738,10 @@ pub fn attributes_writeback(
let mut values = Vec::new();
let mut scratch_buffer = Vec::new();
for val in (*globals).values() {
let val = val.as_ref(py);
let val = val.bind_borrowed(py);
let ty = inner_resolver.get_obj_type(
py,
val,
val.as_gil_ref(),
&mut ctx.unifier,
&top_levels,
&ctx.primitives,
@ -757,7 +757,9 @@ pub fn attributes_writeback(
// we only care about primitive attributes
// for non-primitive attributes, they should be in another global
let mut attributes = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
let obj = inner_resolver
.get_obj_value(py, val.as_gil_ref(), ctx, generator, ty)?
.unwrap();
for (name, (field_ty, is_mutable)) in fields {
if !is_mutable {
continue;
@ -776,7 +778,7 @@ pub fn attributes_writeback(
}
}
if !attributes.is_empty() {
let pydict = PyDict::new(py);
let pydict = PyDict::new_bound(py);
pydict.set_item("obj", val)?;
pydict.set_item("fields", attributes)?;
host_attributes.append(pydict)?;
@ -786,12 +788,14 @@ pub fn attributes_writeback(
let elem_ty = iter_type_vars(params).next().unwrap().ty;
if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() {
let pydict = PyDict::new(py);
let pydict = PyDict::new_bound(py);
pydict.set_item("obj", val)?;
host_attributes.append(pydict)?;
values.push((
ty,
inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(),
inner_resolver
.get_obj_value(py, val.as_gil_ref(), ctx, generator, ty)?
.unwrap(),
));
}
}

View File

@ -40,9 +40,11 @@ use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program,
};
use pyo3::create_exception;
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use pyo3::{
create_exception, exceptions,
prelude::*,
types::{PyBytes, PyDict, PySet},
};
use parking_lot::{Mutex, RwLock};
@ -174,7 +176,7 @@ impl Nac3 {
// Drop unregistered (i.e. host-only) base classes.
bases.retain(|base| {
Python::with_gil(|py| -> PyResult<bool> {
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
let id_fn = PyModule::import_bound(py, "builtins")?.getattr("id")?;
match &base.node {
ExprKind::Name { id, .. } => {
if *id == "Exception".into() {
@ -361,10 +363,10 @@ impl Nac3 {
fn compile_method<T>(
&self,
obj: &PyAny,
obj: &Bound<PyAny>,
method_name: &str,
args: Vec<&PyAny>,
embedding_map: &PyAny,
embedding_map: &Bound<PyAny>,
py: Python,
link_fn: &dyn Fn(&Module) -> PyResult<T>,
) -> PyResult<T> {
@ -376,8 +378,8 @@ impl Nac3 {
size_t,
);
let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let builtins = PyModule::import_bound(py, "builtins")?;
let typings = PyModule::import_bound(py, "typing")?;
let id_fn = builtins.getattr("id")?;
let issubclass = builtins.getattr("issubclass")?;
let exn_class = builtins.getattr("Exception")?;
@ -421,7 +423,7 @@ impl Nac3 {
let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node {
let class = py_module.getattr(name.to_string().as_str()).unwrap();
if issubclass.call1((class, exn_class)).unwrap().extract().unwrap()
if issubclass.call1((class, exn_class.as_gil_ref())).unwrap().extract().unwrap()
&& class.getattr("artiq_builtin").is_err()
{
class_obj = Some(class);
@ -514,15 +516,15 @@ impl Nac3 {
}
}
let id_fun = PyModule::import(py, "builtins")?.getattr("id")?;
let id_fun = PyModule::import_bound(py, "builtins")?.getattr("id")?;
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let module = PyModule::new(py, "tmp")?;
let module = PyModule::new_bound(py, "tmp")?;
module.add("base", obj)?;
name_to_pyid.insert("base".into(), id_fun.call1((obj,))?.extract()?);
let mut arg_names = vec![];
for (i, arg) in args.into_iter().enumerate() {
let name = format!("tmp{i}");
module.add(&name, arg)?;
module.add(&*name, arg)?;
name_to_pyid.insert(name.clone().into(), id_fun.call1((arg,))?.extract()?);
arg_names.push(name);
}
@ -899,7 +901,7 @@ fn add_exceptions(
#[pymethods]
impl Nac3 {
#[new]
fn new(isa: &str, artiq_builtins: &PyDict, py: Python) -> PyResult<Self> {
fn new(isa: &str, artiq_builtins: &Bound<PyDict>, py: Python) -> PyResult<Self> {
let isa = match isa {
"host" => Isa::Host,
"rv32g" => Isa::RiscV32G,
@ -963,43 +965,45 @@ impl Nac3 {
),
];
let builtins_mod = PyModule::import(py, "builtins").unwrap();
let builtins_mod = PyModule::import_bound(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap();
let typing_mod = PyModule::import(py, "typing").unwrap();
let types_mod = PyModule::import(py, "types").unwrap();
let numpy_mod = PyModule::import_bound(py, "numpy").unwrap();
let typing_mod = PyModule::import_bound(py, "typing").unwrap();
let types_mod = PyModule::import_bound(py, "types").unwrap();
let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap();
let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(|id| id.extract()).unwrap();
let get_attr_id = |obj: &PyModule, attr| {
id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap()
};
let primitive_ids = PrimitivePythonId {
virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()),
virtual_id: get_id(
artiq_builtins.get_item("virtual").ok().flatten().unwrap().as_gil_ref(),
),
generic_alias: (
get_attr_id(typing_mod, "_GenericAlias"),
get_attr_id(types_mod, "GenericAlias"),
get_attr_id(typing_mod.as_gil_ref(), "_GenericAlias"),
get_attr_id(types_mod.as_gil_ref(), "GenericAlias"),
),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()),
typevar: get_attr_id(typing_mod, "TypeVar"),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap().as_gil_ref()),
typevar: get_attr_id(typing_mod.as_gil_ref(), "TypeVar"),
const_generic_marker: get_id(
artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(),
artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap().as_gil_ref(),
),
int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"),
int64: get_attr_id(numpy_mod, "int64"),
uint32: get_attr_id(numpy_mod, "uint32"),
uint64: get_attr_id(numpy_mod, "uint64"),
bool: get_attr_id(builtins_mod, "bool"),
np_bool_: get_attr_id(numpy_mod, "bool_"),
string: get_attr_id(builtins_mod, "str"),
np_str_: get_attr_id(numpy_mod, "str_"),
float: get_attr_id(builtins_mod, "float"),
float64: get_attr_id(numpy_mod, "float64"),
list: get_attr_id(builtins_mod, "list"),
ndarray: get_attr_id(numpy_mod, "ndarray"),
tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
int: get_attr_id(builtins_mod.as_gil_ref(), "int"),
int32: get_attr_id(numpy_mod.as_gil_ref(), "int32"),
int64: get_attr_id(numpy_mod.as_gil_ref(), "int64"),
uint32: get_attr_id(numpy_mod.as_gil_ref(), "uint32"),
uint64: get_attr_id(numpy_mod.as_gil_ref(), "uint64"),
bool: get_attr_id(builtins_mod.as_gil_ref(), "bool"),
np_bool_: get_attr_id(numpy_mod.as_gil_ref(), "bool_"),
string: get_attr_id(builtins_mod.as_gil_ref(), "str"),
np_str_: get_attr_id(numpy_mod.as_gil_ref(), "str_"),
float: get_attr_id(builtins_mod.as_gil_ref(), "float"),
float64: get_attr_id(numpy_mod.as_gil_ref(), "float64"),
list: get_attr_id(builtins_mod.as_gil_ref(), "list"),
ndarray: get_attr_id(numpy_mod.as_gil_ref(), "ndarray"),
tuple: get_attr_id(builtins_mod.as_gil_ref(), "tuple"),
exception: get_attr_id(builtins_mod.as_gil_ref(), "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap().as_gil_ref()),
};
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
@ -1024,23 +1028,23 @@ impl Nac3 {
})
}
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
fn analyze(&mut self, functions: &Bound<PySet>, classes: &Bound<PySet>) -> PyResult<()> {
let (modules, class_ids) =
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
let mut modules: HashMap<u64, PyObject> = HashMap::new();
let mut class_ids: HashSet<u64> = HashSet::new();
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
let id_fn = PyModule::import_bound(py, "builtins")?.getattr("id")?;
let getmodule_fn = PyModule::import_bound(py, "inspect")?.getattr("getmodule")?;
for function in functions {
let module = getmodule_fn.call1((function,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
for class in classes {
let module = getmodule_fn.call1((class,))?.extract()?;
let module = getmodule_fn.call1((class.as_gil_ref(),))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
class_ids.insert(id_fn.call1((class,))?.extract()?);
class_ids.insert(id_fn.call1((class.as_gil_ref(),))?.extract()?);
}
Ok((modules, class_ids))
})?;
@ -1053,11 +1057,11 @@ impl Nac3 {
fn compile_method_to_file(
&mut self,
obj: &PyAny,
obj: &Bound<PyAny>,
method_name: &str,
args: Vec<&PyAny>,
filename: &str,
embedding_map: &PyAny,
embedding_map: &Bound<PyAny>,
py: Python,
) -> PyResult<()> {
let target_machine = self.get_llvm_target_machine();
@ -1099,10 +1103,10 @@ impl Nac3 {
fn compile_method_to_mem(
&mut self,
obj: &PyAny,
obj: &Bound<PyAny>,
method_name: &str,
args: Vec<&PyAny>,
embedding_map: &PyAny,
embedding_map: &Bound<PyAny>,
py: Python,
) -> PyResult<PyObject> {
let target_machine = self.get_llvm_target_machine();
@ -1121,7 +1125,7 @@ impl Nac3 {
working_directory.join("module.o").to_string_lossy().to_string(),
)?;
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())
Ok(PyBytes::new_bound(py, &fs::read(filename).unwrap()).into())
};
self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
@ -1131,7 +1135,7 @@ impl Nac3 {
.write_to_memory_buffer(module, FileType::Object)
.expect("couldn't write module to object file buffer");
if let Ok(dyn_lib) = Linker::ld(object_mem.as_slice()) {
Ok(PyBytes::new(py, &dyn_lib).into())
Ok(PyBytes::new_bound(py, &dyn_lib).into())
} else {
Err(CompileError::new_err("linker failed to process object file"))
}
@ -1148,14 +1152,14 @@ extern "C" {
}
#[pymodule]
fn nac3artiq(py: Python, m: &PyModule) -> PyResult<()> {
fn nac3artiq(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
#[cfg(feature = "init-llvm-profile")]
unsafe {
__llvm_profile_initialize();
}
Target::initialize_all(&InitializationConfig::default());
m.add("CompileError", py.get_type::<CompileError>())?;
m.add("CompileError", py.get_type_bound::<CompileError>())?;
m.add_class::<Nac3>()?;
Ok(())
}

View File

@ -25,8 +25,8 @@ use nac3core::{
use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock};
use pyo3::{
prelude::*,
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use std::{
collections::{HashMap, HashSet},
@ -174,7 +174,7 @@ impl StaticValue for PythonValue {
Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
self.resolver
.get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty)
.get_obj_value(py, self.value.bind(py).as_gil_ref(), ctx, generator, expected_ty)
.map(Option::unwrap)
})
.map_err(|e| e.to_string())
@ -462,22 +462,27 @@ impl InnerResolver {
{
let origin = self.helper.origin_ty_fn.call1(py, (pyty,))?;
let args = self.helper.args_ty_fn.call1(py, (pyty,))?;
let args: &PyTuple = args.downcast(py)?;
let origin_ty =
match self.get_pyty_obj_type(py, origin.as_ref(py), unifier, defs, primitives)? {
Ok((ty, false)) => ty,
Ok((_, true)) => {
return Ok(Err("instantiated type does not take type parameters".into()))
}
Err(err) => return Ok(Err(err)),
};
let args = args.downcast_bound::<PyTuple>(py)?;
let origin_ty = match self.get_pyty_obj_type(
py,
origin.bind(py).as_gil_ref(),
unifier,
defs,
primitives,
)? {
Ok((ty, false)) => ty,
Ok((_, true)) => {
return Ok(Err("instantiated type does not take type parameters".into()))
}
Err(err) => return Ok(Err(err)),
};
match &*unifier.get_ty(origin_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(
py,
args.get_item(0)?,
args.get_item(0)?.as_gil_ref(),
unifier,
defs,
primitives,
@ -523,9 +528,15 @@ impl InnerResolver {
// npt.NDArray[T] == np.ndarray[Any, np.dtype[T]]
let ndarray_dtype_pyty =
self.helper.args_ty_fn.call1(py, (args.get_item(1)?,))?;
let dtype = ndarray_dtype_pyty.downcast::<PyTuple>(py)?.get_item(0)?;
let dtype = ndarray_dtype_pyty.downcast_bound::<PyTuple>(py)?.get_item(0)?;
let ty = match self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)? {
let ty = match self.get_pyty_obj_type(
py,
dtype.as_gil_ref(),
unifier,
defs,
primitives,
)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err)),
};
@ -541,7 +552,7 @@ impl InnerResolver {
TypeEnum::TTuple { .. } => {
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(py, x, unifier, defs, primitives))
.map(|x| self.get_pyty_obj_type(py, x.as_gil_ref(), unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
@ -574,7 +585,7 @@ impl InnerResolver {
}
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(py, x, unifier, defs, primitives))
.map(|x| self.get_pyty_obj_type(py, x.as_gil_ref(), unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
@ -601,7 +612,7 @@ impl InnerResolver {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(
py,
args.get_item(0)?,
args.get_item(0)?.as_gil_ref(),
unifier,
defs,
primitives,
@ -632,8 +643,7 @@ impl InnerResolver {
false,
)))
} else {
let str_fn =
pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap();
let str_fn = PyModule::import_bound(py, "builtins").unwrap().getattr("repr").unwrap();
let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap();
Ok(Err(format!("{str_repr} is not registered with NAC3 (@nac3 decorator missing?)")))
}
@ -689,7 +699,7 @@ impl InnerResolver {
{
obj
} else {
ty.as_ref(py)
ty.bind(py).as_gil_ref()
}
},
unifier,
@ -1550,10 +1560,16 @@ impl SymbolResolver for Resolver {
let store = self.0.deferred_eval_store.store.read();
Python::with_gil(|py| -> PyResult<Result<(), String>> {
for (variables, constraints, name) in store.iter() {
let constraints: &PyAny = constraints.as_ref(py);
let constraints = constraints.bind(py);
for (i, var) in variables.iter().enumerate() {
if let Ok(constr) = constraints.get_item(i) {
match self.0.get_pyty_obj_type(py, constr, unifier, defs, primitives)? {
match self.0.get_pyty_obj_type(
py,
constr.as_gil_ref(),
unifier,
defs,
primitives,
)? {
Ok((ty, _)) => {
if !unifier.is_concrete(ty, &[]) {
return Ok(Err(format!(