forked from M-Labs/nac3
1
0
Fork 0

artiq: Update to pyo3 v0.21

With the extensive use of as_gil_ref. Will have to refactor those away
as well.
This commit is contained in:
David Mak 2024-07-05 12:54:15 +08:00
parent 25d2de67f7
commit 317503679e
5 changed files with 127 additions and 104 deletions

36
Cargo.lock generated
View File

@ -1398,9 +1398,9 @@ dependencies = [
[[package]]
name = "windows-targets"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
@ -1414,51 +1414,51 @@ dependencies = [
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "yaml-rust"

View File

@ -10,7 +10,7 @@ crate-type = ["cdylib"]
[dependencies]
itertools = "0.13"
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
pyo3 = { version = "0.21", features = ["extension-module"] }
parking_lot = "0.12"
tempfile = "3.10"
nac3parser = { path = "../nac3parser" }

View File

@ -17,8 +17,8 @@ use inkwell::{
};
use pyo3::{
prelude::*,
types::{PyDict, PyList},
PyObject, PyResult, Python,
};
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
@ -624,7 +624,7 @@ pub fn attributes_writeback(
host_attributes: &PyObject,
) -> Result<(), String> {
Python::with_gil(|py| -> PyResult<Result<(), String>> {
let host_attributes: &PyList = host_attributes.downcast(py)?;
let host_attributes = host_attributes.downcast_bound::<PyList>(py)?;
let top_levels = ctx.top_level.definitions.read();
let globals = inner_resolver.global_value_ids.read();
let int32 = ctx.ctx.i32_type();
@ -632,10 +632,10 @@ pub fn attributes_writeback(
let mut values = Vec::new();
let mut scratch_buffer = Vec::new();
for val in (*globals).values() {
let val = val.as_ref(py);
let val = val.bind_borrowed(py);
let ty = inner_resolver.get_obj_type(
py,
val,
val.as_gil_ref(),
&mut ctx.unifier,
&top_levels,
&ctx.primitives,
@ -651,7 +651,9 @@ pub fn attributes_writeback(
// we only care about primitive attributes
// for non-primitive attributes, they should be in another global
let mut attributes = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
let obj = inner_resolver
.get_obj_value(py, val.as_gil_ref(), ctx, generator, ty)?
.unwrap();
for (name, (field_ty, is_mutable)) in fields {
if !is_mutable {
continue;
@ -670,7 +672,7 @@ pub fn attributes_writeback(
}
}
if !attributes.is_empty() {
let pydict = PyDict::new(py);
let pydict = PyDict::new_bound(py);
pydict.set_item("obj", val)?;
pydict.set_item("fields", attributes)?;
host_attributes.append(pydict)?;
@ -680,12 +682,14 @@ pub fn attributes_writeback(
let elem_ty = iter_type_vars(params).next().unwrap().ty;
if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() {
let pydict = PyDict::new(py);
let pydict = PyDict::new_bound(py);
pydict.set_item("obj", val)?;
host_attributes.append(pydict)?;
values.push((
ty,
inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(),
inner_resolver
.get_obj_value(py, val.as_gil_ref(), ctx, generator, ty)?
.unwrap(),
));
}
}

View File

@ -39,9 +39,11 @@ use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program,
};
use pyo3::create_exception;
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use pyo3::{
create_exception, exceptions,
prelude::*,
types::{PyBytes, PyDict, PySet},
};
use parking_lot::{Mutex, RwLock};
@ -173,7 +175,7 @@ impl Nac3 {
// Drop unregistered (i.e. host-only) base classes.
bases.retain(|base| {
Python::with_gil(|py| -> PyResult<bool> {
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
let id_fn = PyModule::import_bound(py, "builtins")?.getattr("id")?;
match &base.node {
ExprKind::Name { id, .. } => {
if *id == "Exception".into() {
@ -302,10 +304,10 @@ impl Nac3 {
fn compile_method<T>(
&self,
obj: &PyAny,
obj: &Bound<PyAny>,
method_name: &str,
args: Vec<&PyAny>,
embedding_map: &PyAny,
embedding_map: &Bound<PyAny>,
py: Python,
link_fn: &dyn Fn(&Module) -> PyResult<T>,
) -> PyResult<T> {
@ -316,8 +318,8 @@ impl Nac3 {
size_t,
);
let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let builtins = PyModule::import_bound(py, "builtins")?;
let typings = PyModule::import_bound(py, "typing")?;
let id_fn = builtins.getattr("id")?;
let issubclass = builtins.getattr("issubclass")?;
let exn_class = builtins.getattr("Exception")?;
@ -361,7 +363,7 @@ impl Nac3 {
let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node {
let class = py_module.getattr(name.to_string().as_str()).unwrap();
if issubclass.call1((class, exn_class)).unwrap().extract().unwrap()
if issubclass.call1((class, exn_class.as_gil_ref())).unwrap().extract().unwrap()
&& class.getattr("artiq_builtin").is_err()
{
class_obj = Some(class);
@ -454,15 +456,15 @@ impl Nac3 {
}
}
let id_fun = PyModule::import(py, "builtins")?.getattr("id")?;
let id_fun = PyModule::import_bound(py, "builtins")?.getattr("id")?;
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let module = PyModule::new(py, "tmp")?;
let module = PyModule::new_bound(py, "tmp")?;
module.add("base", obj)?;
name_to_pyid.insert("base".into(), id_fun.call1((obj,))?.extract()?);
let mut arg_names = vec![];
for (i, arg) in args.into_iter().enumerate() {
let name = format!("tmp{i}");
module.add(&name, arg)?;
module.add(&*name, arg)?;
name_to_pyid.insert(name.clone().into(), id_fun.call1((arg,))?.extract()?);
arg_names.push(name);
}
@ -834,7 +836,7 @@ fn add_exceptions(
#[pymethods]
impl Nac3 {
#[new]
fn new(isa: &str, artiq_builtins: &PyDict, py: Python) -> PyResult<Self> {
fn new(isa: &str, artiq_builtins: &Bound<PyDict>, py: Python) -> PyResult<Self> {
let isa = match isa {
"host" => Isa::Host,
"rv32g" => Isa::RiscV32G,
@ -896,43 +898,45 @@ impl Nac3 {
),
];
let builtins_mod = PyModule::import(py, "builtins").unwrap();
let builtins_mod = PyModule::import_bound(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap();
let typing_mod = PyModule::import(py, "typing").unwrap();
let types_mod = PyModule::import(py, "types").unwrap();
let numpy_mod = PyModule::import_bound(py, "numpy").unwrap();
let typing_mod = PyModule::import_bound(py, "typing").unwrap();
let types_mod = PyModule::import_bound(py, "types").unwrap();
let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap();
let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(|id| id.extract()).unwrap();
let get_attr_id = |obj: &PyModule, attr| {
id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap()
};
let primitive_ids = PrimitivePythonId {
virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()),
virtual_id: get_id(
artiq_builtins.get_item("virtual").ok().flatten().unwrap().as_gil_ref(),
),
generic_alias: (
get_attr_id(typing_mod, "_GenericAlias"),
get_attr_id(types_mod, "GenericAlias"),
get_attr_id(typing_mod.as_gil_ref(), "_GenericAlias"),
get_attr_id(types_mod.as_gil_ref(), "GenericAlias"),
),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()),
typevar: get_attr_id(typing_mod, "TypeVar"),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap().as_gil_ref()),
typevar: get_attr_id(typing_mod.as_gil_ref(), "TypeVar"),
const_generic_marker: get_id(
artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(),
artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap().as_gil_ref(),
),
int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"),
int64: get_attr_id(numpy_mod, "int64"),
uint32: get_attr_id(numpy_mod, "uint32"),
uint64: get_attr_id(numpy_mod, "uint64"),
bool: get_attr_id(builtins_mod, "bool"),
np_bool_: get_attr_id(numpy_mod, "bool_"),
string: get_attr_id(builtins_mod, "str"),
np_str_: get_attr_id(numpy_mod, "str_"),
float: get_attr_id(builtins_mod, "float"),
float64: get_attr_id(numpy_mod, "float64"),
list: get_attr_id(builtins_mod, "list"),
ndarray: get_attr_id(numpy_mod, "ndarray"),
tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
int: get_attr_id(builtins_mod.as_gil_ref(), "int"),
int32: get_attr_id(numpy_mod.as_gil_ref(), "int32"),
int64: get_attr_id(numpy_mod.as_gil_ref(), "int64"),
uint32: get_attr_id(numpy_mod.as_gil_ref(), "uint32"),
uint64: get_attr_id(numpy_mod.as_gil_ref(), "uint64"),
bool: get_attr_id(builtins_mod.as_gil_ref(), "bool"),
np_bool_: get_attr_id(numpy_mod.as_gil_ref(), "bool_"),
string: get_attr_id(builtins_mod.as_gil_ref(), "str"),
np_str_: get_attr_id(numpy_mod.as_gil_ref(), "str_"),
float: get_attr_id(builtins_mod.as_gil_ref(), "float"),
float64: get_attr_id(numpy_mod.as_gil_ref(), "float64"),
list: get_attr_id(builtins_mod.as_gil_ref(), "list"),
ndarray: get_attr_id(numpy_mod.as_gil_ref(), "ndarray"),
tuple: get_attr_id(builtins_mod.as_gil_ref(), "tuple"),
exception: get_attr_id(builtins_mod.as_gil_ref(), "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap().as_gil_ref()),
};
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
@ -957,23 +961,23 @@ impl Nac3 {
})
}
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
fn analyze(&mut self, functions: &Bound<PySet>, classes: &Bound<PySet>) -> PyResult<()> {
let (modules, class_ids) =
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
let mut modules: HashMap<u64, PyObject> = HashMap::new();
let mut class_ids: HashSet<u64> = HashSet::new();
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
let id_fn = PyModule::import_bound(py, "builtins")?.getattr("id")?;
let getmodule_fn = PyModule::import_bound(py, "inspect")?.getattr("getmodule")?;
for function in functions {
let module = getmodule_fn.call1((function,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
for class in classes {
let module = getmodule_fn.call1((class,))?.extract()?;
let module = getmodule_fn.call1((class.as_gil_ref(),))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
class_ids.insert(id_fn.call1((class,))?.extract()?);
class_ids.insert(id_fn.call1((class.as_gil_ref(),))?.extract()?);
}
Ok((modules, class_ids))
})?;
@ -986,11 +990,11 @@ impl Nac3 {
fn compile_method_to_file(
&mut self,
obj: &PyAny,
obj: &Bound<PyAny>,
method_name: &str,
args: Vec<&PyAny>,
filename: &str,
embedding_map: &PyAny,
embedding_map: &Bound<PyAny>,
py: Python,
) -> PyResult<()> {
let target_machine = self.get_llvm_target_machine();
@ -1032,10 +1036,10 @@ impl Nac3 {
fn compile_method_to_mem(
&mut self,
obj: &PyAny,
obj: &Bound<PyAny>,
method_name: &str,
args: Vec<&PyAny>,
embedding_map: &PyAny,
embedding_map: &Bound<PyAny>,
py: Python,
) -> PyResult<PyObject> {
let target_machine = self.get_llvm_target_machine();
@ -1054,7 +1058,7 @@ impl Nac3 {
working_directory.join("module.o").to_string_lossy().to_string(),
)?;
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())
Ok(PyBytes::new_bound(py, &fs::read(filename).unwrap()).into())
};
self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
@ -1064,7 +1068,7 @@ impl Nac3 {
.write_to_memory_buffer(module, FileType::Object)
.expect("couldn't write module to object file buffer");
if let Ok(dyn_lib) = Linker::ld(object_mem.as_slice()) {
Ok(PyBytes::new(py, &dyn_lib).into())
Ok(PyBytes::new_bound(py, &dyn_lib).into())
} else {
Err(CompileError::new_err("linker failed to process object file"))
}
@ -1081,14 +1085,14 @@ extern "C" {
}
#[pymodule]
fn nac3artiq(py: Python, m: &PyModule) -> PyResult<()> {
fn nac3artiq(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
#[cfg(feature = "init-llvm-profile")]
unsafe {
__llvm_profile_initialize();
}
Target::initialize_all(&InitializationConfig::default());
m.add("CompileError", py.get_type::<CompileError>())?;
m.add("CompileError", py.get_type_bound::<CompileError>())?;
m.add_class::<Nac3>()?;
Ok(())
}

View File

@ -1,3 +1,4 @@
use crate::PrimitivePythonId;
use inkwell::{
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
@ -23,8 +24,8 @@ use nac3core::{
use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock};
use pyo3::{
prelude::*,
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use std::{
collections::{HashMap, HashSet},
@ -34,8 +35,6 @@ use std::{
},
};
use crate::PrimitivePythonId;
pub enum PrimitiveValue {
I32(i32),
I64(i64),
@ -172,7 +171,7 @@ impl StaticValue for PythonValue {
Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
self.resolver
.get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty)
.get_obj_value(py, self.value.bind(py).as_gil_ref(), ctx, generator, expected_ty)
.map(Option::unwrap)
})
.map_err(|e| e.to_string())
@ -460,9 +459,14 @@ impl InnerResolver {
{
let origin = self.helper.origin_ty_fn.call1(py, (pyty,))?;
let args = self.helper.args_ty_fn.call1(py, (pyty,))?;
let args: &PyTuple = args.downcast(py)?;
let origin_ty =
match self.get_pyty_obj_type(py, origin.as_ref(py), unifier, defs, primitives)? {
let args = args.downcast_bound::<PyTuple>(py)?;
let origin_ty = match self.get_pyty_obj_type(
py,
origin.bind(py).as_gil_ref(),
unifier,
defs,
primitives,
)? {
Ok((ty, false)) => ty,
Ok((_, true)) => {
return Ok(Err("instantiated type does not take type parameters".into()))
@ -475,7 +479,7 @@ impl InnerResolver {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(
py,
args.get_item(0)?,
args.get_item(0)?.as_gil_ref(),
unifier,
defs,
primitives,
@ -521,9 +525,15 @@ impl InnerResolver {
// npt.NDArray[T] == np.ndarray[Any, np.dtype[T]]
let ndarray_dtype_pyty =
self.helper.args_ty_fn.call1(py, (args.get_item(1)?,))?;
let dtype = ndarray_dtype_pyty.downcast::<PyTuple>(py)?.get_item(0)?;
let dtype = ndarray_dtype_pyty.downcast_bound::<PyTuple>(py)?.get_item(0)?;
let ty = match self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)? {
let ty = match self.get_pyty_obj_type(
py,
dtype.as_gil_ref(),
unifier,
defs,
primitives,
)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err)),
};
@ -539,7 +549,7 @@ impl InnerResolver {
TypeEnum::TTuple { .. } => {
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(py, x, unifier, defs, primitives))
.map(|x| self.get_pyty_obj_type(py, x.as_gil_ref(), unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
@ -569,7 +579,7 @@ impl InnerResolver {
}
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(py, x, unifier, defs, primitives))
.map(|x| self.get_pyty_obj_type(py, x.as_gil_ref(), unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
@ -596,7 +606,7 @@ impl InnerResolver {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(
py,
args.get_item(0)?,
args.get_item(0)?.as_gil_ref(),
unifier,
defs,
primitives,
@ -627,8 +637,7 @@ impl InnerResolver {
false,
)))
} else {
let str_fn =
pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap();
let str_fn = PyModule::import_bound(py, "builtins").unwrap().getattr("repr").unwrap();
let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap();
Ok(Err(format!("{str_repr} is not registered with NAC3 (@nac3 decorator missing?)")))
}
@ -684,7 +693,7 @@ impl InnerResolver {
{
obj
} else {
ty.as_ref(py)
ty.bind(py).as_gil_ref()
}
},
unifier,
@ -1534,10 +1543,16 @@ impl SymbolResolver for Resolver {
let store = self.0.deferred_eval_store.store.read();
Python::with_gil(|py| -> PyResult<Result<(), String>> {
for (variables, constraints, name) in store.iter() {
let constraints: &PyAny = constraints.as_ref(py);
let constraints = constraints.bind(py);
for (i, var) in variables.iter().enumerate() {
if let Ok(constr) = constraints.get_item(i) {
match self.0.get_pyty_obj_type(py, constr, unifier, defs, primitives)? {
match self.0.get_pyty_obj_type(
py,
constr.as_gil_ref(),
unifier,
defs,
primitives,
)? {
Ok((ty, _)) => {
if !unifier.is_concrete(ty, &[]) {
return Ok(Err(format!(