1
0
forked from M-Labs/nac3

Compare commits

...

3 Commits

Author SHA1 Message Date
ad94f59a9d artiq: Remove all uses to gil-refs APIs 2024-09-13 11:18:09 +08:00
246d2f6d05 artiq: Update to pyo3 v0.22 with gil-refs feature 2024-09-13 11:18:09 +08:00
19d183ed84 artiq: Update to pyo3 v0.21
With the extensive use of as_gil_ref. Will have to refactor those away
as well.
2024-09-13 11:18:06 +08:00
5 changed files with 250 additions and 176 deletions

34
Cargo.lock generated
View File

@ -167,7 +167,7 @@ version = "4.5.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.77",
@ -406,12 +406,6 @@ dependencies = [
"ahash",
]
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "heck"
version = "0.5.0"
@ -849,15 +843,15 @@ dependencies = [
[[package]]
name = "pyo3"
version = "0.21.2"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8"
checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433"
dependencies = [
"cfg-if",
"indoc",
"libc",
"memoffset",
"parking_lot",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
@ -867,9 +861,9 @@ dependencies = [
[[package]]
name = "pyo3-build-config"
version = "0.21.2"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50"
checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8"
dependencies = [
"once_cell",
"target-lexicon",
@ -877,9 +871,9 @@ dependencies = [
[[package]]
name = "pyo3-ffi"
version = "0.21.2"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403"
checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6"
dependencies = [
"libc",
"pyo3-build-config",
@ -887,9 +881,9 @@ dependencies = [
[[package]]
name = "pyo3-macros"
version = "0.21.2"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c"
checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
@ -899,11 +893,11 @@ dependencies = [
[[package]]
name = "pyo3-macros-backend"
version = "0.21.2"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c"
checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372"
dependencies = [
"heck 0.4.1",
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
@ -1191,7 +1185,7 @@ version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"rustversion",

View File

@ -10,7 +10,7 @@ crate-type = ["cdylib"]
[dependencies]
itertools = "0.13"
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
pyo3 = { version = "0.22", features = ["extension-module", "py-clone"] }
parking_lot = "0.12"
tempfile = "3.10"
nac3core = { path = "../nac3core" }

View File

@ -26,8 +26,8 @@ use nac3core::inkwell::{
};
use pyo3::{
prelude::*,
types::{PyDict, PyList},
PyObject, PyResult, Python,
};
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
@ -970,7 +970,7 @@ pub fn attributes_writeback(
host_attributes: &PyObject,
) -> Result<(), String> {
Python::with_gil(|py| -> PyResult<Result<(), String>> {
let host_attributes: &PyList = host_attributes.downcast(py)?;
let host_attributes = host_attributes.downcast_bound::<PyList>(py)?;
let top_levels = ctx.top_level.definitions.read();
let globals = inner_resolver.global_value_ids.read();
let int32 = ctx.ctx.i32_type();
@ -978,7 +978,7 @@ pub fn attributes_writeback(
let mut values = Vec::new();
let mut scratch_buffer = Vec::new();
for val in (*globals).values() {
let val = val.as_ref(py);
let val = val.bind_borrowed(py);
let ty = inner_resolver.get_obj_type(
py,
val,
@ -1016,7 +1016,7 @@ pub fn attributes_writeback(
}
}
if !attributes.is_empty() {
let pydict = PyDict::new(py);
let pydict = PyDict::new_bound(py);
pydict.set_item("obj", val)?;
pydict.set_item("fields", attributes)?;
host_attributes.append(pydict)?;
@ -1026,7 +1026,7 @@ pub fn attributes_writeback(
let elem_ty = iter_type_vars(params).next().unwrap().ty;
if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() {
let pydict = PyDict::new(py);
let pydict = PyDict::new_bound(py);
pydict.set_item("obj", val)?;
host_attributes.append(pydict)?;
values.push((

View File

@ -40,9 +40,11 @@ use nac3core::nac3parser::{
};
use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap};
use pyo3::create_exception;
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use pyo3::{
create_exception, exceptions,
prelude::*,
types::{PyBytes, PyDict, PySet},
};
use parking_lot::{Mutex, RwLock};
@ -148,7 +150,7 @@ impl Nac3 {
registered_class_ids: &HashSet<u64>,
) -> PyResult<()> {
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let module: &PyAny = module.extract(py)?;
let module = module.bind_borrowed(py);
Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?))
})?;
@ -174,14 +176,14 @@ impl Nac3 {
// Drop unregistered (i.e. host-only) base classes.
bases.retain(|base| {
Python::with_gil(|py| -> PyResult<bool> {
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
let module = module.bind_borrowed(py);
let id_fn = PyModule::import_bound(py, "builtins")?.getattr("id")?;
match &base.node {
ExprKind::Name { id, .. } => {
if *id == "Exception".into() {
Ok(true)
} else {
let base_obj =
module.getattr(py, id.to_string().as_str())?;
let base_obj = module.getattr(id.to_string().as_str())?;
let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id))
}
@ -361,10 +363,10 @@ impl Nac3 {
fn compile_method<T>(
&self,
obj: &PyAny,
obj: &Bound<PyAny>,
method_name: &str,
args: Vec<&PyAny>,
embedding_map: &PyAny,
args: Vec<Bound<PyAny>>,
embedding_map: &Bound<PyAny>,
py: Python,
link_fn: &dyn Fn(&Module) -> PyResult<T>,
) -> PyResult<T> {
@ -376,8 +378,8 @@ impl Nac3 {
size_t,
);
let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let builtins = PyModule::import_bound(py, "builtins")?;
let typings = PyModule::import_bound(py, "typing")?;
let id_fn = builtins.getattr("id")?;
let issubclass = builtins.getattr("issubclass")?;
let exn_class = builtins.getattr("Exception")?;
@ -415,13 +417,17 @@ impl Nac3 {
let mut rpc_ids = vec![];
for (stmt, path, module) in &self.top_levels {
let py_module: &PyAny = module.extract(py)?;
let py_module = module.bind_borrowed(py);
let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
let helper = helper.clone();
let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node {
let class = py_module.getattr(name.to_string().as_str()).unwrap();
if issubclass.call1((class, exn_class)).unwrap().extract().unwrap()
if issubclass
.call1((class.as_borrowed(), exn_class.as_borrowed()))
.unwrap()
.extract()
.unwrap()
&& class.getattr("artiq_builtin").is_err()
{
class_obj = Some(class);
@ -434,8 +440,8 @@ impl Nac3 {
let (name_to_pyid, resolver) =
module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let members: &PyDict =
py_module.getattr("__dict__").unwrap().downcast().unwrap();
let members = py_module.getattr("__dict__").unwrap();
let members = members.downcast::<PyDict>().unwrap();
for (key, val) in members {
let key: &str = key.extract().unwrap();
let val = id_fn.call1((val,)).unwrap().extract().unwrap();
@ -513,15 +519,15 @@ impl Nac3 {
}
}
let id_fun = PyModule::import(py, "builtins")?.getattr("id")?;
let id_fun = PyModule::import_bound(py, "builtins")?.getattr("id")?;
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let module = PyModule::new(py, "tmp")?;
let module = PyModule::new_bound(py, "tmp")?;
module.add("base", obj)?;
name_to_pyid.insert("base".into(), id_fun.call1((obj,))?.extract()?);
let mut arg_names = vec![];
for (i, arg) in args.into_iter().enumerate() {
let name = format!("tmp{i}");
module.add(&name, arg)?;
module.add(&*name, arg.clone())?;
name_to_pyid.insert(name.clone().into(), id_fun.call1((arg,))?.extract()?);
arg_names.push(name);
}
@ -900,7 +906,7 @@ fn add_exceptions(
#[pymethods]
impl Nac3 {
#[new]
fn new(isa: &str, artiq_builtins: &PyDict, py: Python) -> PyResult<Self> {
fn new(isa: &str, artiq_builtins: &Bound<PyDict>, py: Python) -> PyResult<Self> {
let isa = match isa {
"host" => Isa::Host,
"rv32g" => Isa::RiscV32G,
@ -964,43 +970,50 @@ impl Nac3 {
),
];
let builtins_mod = PyModule::import(py, "builtins").unwrap();
let builtins_mod = PyModule::import_bound(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap();
let typing_mod = PyModule::import(py, "typing").unwrap();
let types_mod = PyModule::import(py, "types").unwrap();
let numpy_mod = PyModule::import_bound(py, "numpy").unwrap();
let typing_mod = PyModule::import_bound(py, "typing").unwrap();
let types_mod = PyModule::import_bound(py, "types").unwrap();
let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap();
let get_attr_id = |obj: &PyModule, attr| {
let get_id = |x: Borrowed<PyAny>| id_fn.call1((x,)).and_then(|id| id.extract()).unwrap();
let get_attr_id = |obj: Borrowed<PyModule>, attr| {
id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap()
};
let primitive_ids = PrimitivePythonId {
virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()),
virtual_id: get_id(
artiq_builtins.get_item("virtual").ok().flatten().unwrap().as_borrowed(),
),
generic_alias: (
get_attr_id(typing_mod, "_GenericAlias"),
get_attr_id(types_mod, "GenericAlias"),
get_attr_id(typing_mod.as_borrowed(), "_GenericAlias"),
get_attr_id(types_mod.as_borrowed(), "GenericAlias"),
),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()),
typevar: get_attr_id(typing_mod, "TypeVar"),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap().as_borrowed()),
typevar: get_attr_id(typing_mod.as_borrowed(), "TypeVar"),
const_generic_marker: get_id(
artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(),
artiq_builtins
.get_item("_ConstGenericMarker")
.ok()
.flatten()
.unwrap()
.as_borrowed(),
),
int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"),
int64: get_attr_id(numpy_mod, "int64"),
uint32: get_attr_id(numpy_mod, "uint32"),
uint64: get_attr_id(numpy_mod, "uint64"),
bool: get_attr_id(builtins_mod, "bool"),
np_bool_: get_attr_id(numpy_mod, "bool_"),
string: get_attr_id(builtins_mod, "str"),
np_str_: get_attr_id(numpy_mod, "str_"),
float: get_attr_id(builtins_mod, "float"),
float64: get_attr_id(numpy_mod, "float64"),
list: get_attr_id(builtins_mod, "list"),
ndarray: get_attr_id(numpy_mod, "ndarray"),
tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
int: get_attr_id(builtins_mod.as_borrowed(), "int"),
int32: get_attr_id(numpy_mod.as_borrowed(), "int32"),
int64: get_attr_id(numpy_mod.as_borrowed(), "int64"),
uint32: get_attr_id(numpy_mod.as_borrowed(), "uint32"),
uint64: get_attr_id(numpy_mod.as_borrowed(), "uint64"),
bool: get_attr_id(builtins_mod.as_borrowed(), "bool"),
np_bool_: get_attr_id(numpy_mod.as_borrowed(), "bool_"),
string: get_attr_id(builtins_mod.as_borrowed(), "str"),
np_str_: get_attr_id(numpy_mod.as_borrowed(), "str_"),
float: get_attr_id(builtins_mod.as_borrowed(), "float"),
float64: get_attr_id(numpy_mod.as_borrowed(), "float64"),
list: get_attr_id(builtins_mod.as_borrowed(), "list"),
ndarray: get_attr_id(numpy_mod.as_borrowed(), "ndarray"),
tuple: get_attr_id(builtins_mod.as_borrowed(), "tuple"),
exception: get_attr_id(builtins_mod.as_borrowed(), "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap().as_borrowed()),
};
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
@ -1025,21 +1038,21 @@ impl Nac3 {
})
}
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
fn analyze(&mut self, functions: &Bound<PySet>, classes: &Bound<PySet>) -> PyResult<()> {
let (modules, class_ids) =
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
let mut modules: HashMap<u64, PyObject> = HashMap::new();
let mut class_ids: HashSet<u64> = HashSet::new();
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
let id_fn = PyModule::import_bound(py, "builtins")?.getattr("id")?;
let getmodule_fn = PyModule::import_bound(py, "inspect")?.getattr("getmodule")?;
for function in functions {
let module = getmodule_fn.call1((function,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
for class in classes {
let module = getmodule_fn.call1((class,))?.extract()?;
let module = getmodule_fn.call1((class.as_borrowed(),))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
class_ids.insert(id_fn.call1((class,))?.extract()?);
}
@ -1054,11 +1067,11 @@ impl Nac3 {
fn compile_method_to_file(
&mut self,
obj: &PyAny,
obj: &Bound<PyAny>,
method_name: &str,
args: Vec<&PyAny>,
args: Vec<Bound<PyAny>>,
filename: &str,
embedding_map: &PyAny,
embedding_map: &Bound<PyAny>,
py: Python,
) -> PyResult<()> {
let target_machine = self.get_llvm_target_machine();
@ -1100,10 +1113,10 @@ impl Nac3 {
fn compile_method_to_mem(
&mut self,
obj: &PyAny,
obj: &Bound<PyAny>,
method_name: &str,
args: Vec<&PyAny>,
embedding_map: &PyAny,
args: Vec<Bound<PyAny>>,
embedding_map: &Bound<PyAny>,
py: Python,
) -> PyResult<PyObject> {
let target_machine = self.get_llvm_target_machine();
@ -1122,7 +1135,7 @@ impl Nac3 {
working_directory.join("module.o").to_string_lossy().to_string(),
)?;
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())
Ok(PyBytes::new_bound(py, &fs::read(filename).unwrap()).into())
};
self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
@ -1132,7 +1145,7 @@ impl Nac3 {
.write_to_memory_buffer(module, FileType::Object)
.expect("couldn't write module to object file buffer");
if let Ok(dyn_lib) = Linker::ld(object_mem.as_slice()) {
Ok(PyBytes::new(py, &dyn_lib).into())
Ok(PyBytes::new_bound(py, &dyn_lib).into())
} else {
Err(CompileError::new_err("linker failed to process object file"))
}
@ -1149,14 +1162,14 @@ extern "C" {
}
#[pymodule]
fn nac3artiq(py: Python, m: &PyModule) -> PyResult<()> {
fn nac3artiq(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
#[cfg(feature = "init-llvm-profile")]
unsafe {
__llvm_profile_initialize();
}
Target::initialize_all(&InitializationConfig::default());
m.add("CompileError", py.get_type::<CompileError>())?;
m.add("CompileError", py.get_type_bound::<CompileError>())?;
m.add_class::<Nac3>()?;
Ok(())
}

View File

@ -25,8 +25,8 @@ use nac3core::{
};
use parking_lot::RwLock;
use pyo3::{
prelude::*,
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use std::{
collections::{HashMap, HashSet},
@ -173,7 +173,7 @@ impl StaticValue for PythonValue {
Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
self.resolver
.get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty)
.get_obj_value(py, self.value.bind_borrowed(py), ctx, generator, expected_ty)
.map(Option::unwrap)
})
.map_err(|e| e.to_string())
@ -242,10 +242,10 @@ impl StaticValue for PythonValue {
let ty = helper.type_fn.call1(py, (&self.value,))?;
let ty_id: u64 = helper.id_fn.call1(py, (ty,))?.extract(py)?;
assert_eq!(ty_id, self.resolver.primitive_ids.tuple);
let tup: &PyTuple = self.value.extract(py)?;
let tup = self.value.downcast_bound::<PyTuple>(py)?;
let elem = tup.get_item(index as usize)?;
let id = self.resolver.helper.id_fn.call1(py, (elem,))?.extract(py)?;
Ok(Some((id, elem.into())))
let id = self.resolver.helper.id_fn.call1(py, (elem.as_borrowed(),))?.extract(py)?;
Ok(Some((id, elem.unbind())))
})
.unwrap()
.map(|(id, obj)| {
@ -263,21 +263,26 @@ impl InnerResolver {
fn get_list_elem_type(
&self,
py: Python,
list: &PyAny,
list: Borrowed<PyAny>,
len: usize,
unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
) -> PyResult<Result<Type, String>> {
let mut ty = match self.get_obj_type(py, list.get_item(0)?, unifier, defs, primitives)? {
let mut ty = match self.get_obj_type(
py,
list.get_item(0)?.as_borrowed(),
unifier,
defs,
primitives,
)? {
Ok(t) => t,
Err(e) => return Ok(Err(format!("type error ({e}) at element #0 of the list"))),
};
for i in 1..len {
let b = match list
.get_item(i)
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))??
{
let b = match list.get_item(i).map(|elem| {
self.get_obj_type(py, elem.as_borrowed(), unifier, defs, primitives)
})?? {
Ok(t) => t,
Err(e) => return Ok(Err(format!("type error ({e}) at element #{i} of the list"))),
};
@ -303,7 +308,7 @@ impl InnerResolver {
fn get_pyty_obj_type(
&self,
py: Python,
pyty: &PyAny,
pyty: Borrowed<PyAny>,
unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
@ -391,7 +396,8 @@ impl InnerResolver {
(unifier.add_ty(ty), false)
}))
} else if ty_ty_id == self.primitive_ids.typevar {
let name: &str = pyty.getattr("__name__").unwrap().extract().unwrap();
let name = pyty.getattr("__name__").unwrap();
let name: &str = name.extract().unwrap();
let (constraint_types, is_const_generic) = {
let constraints = pyty.getattr("__constraints__").unwrap();
let mut result: Vec<Type> = vec![];
@ -400,7 +406,8 @@ impl InnerResolver {
let mut is_const_generic = false;
for i in 0usize.. {
if let Ok(constr) = constraints.get_item(i) {
let constr_id: u64 = self.helper.id_fn.call1(py, (constr,))?.extract(py)?;
let constr_id: u64 =
self.helper.id_fn.call1(py, (constr.as_borrowed(),))?.extract(py)?;
if constr_id == self.primitive_ids.const_generic_marker {
is_const_generic = true;
continue;
@ -410,7 +417,7 @@ impl InnerResolver {
result.push(unifier.get_dummy_var().ty);
} else {
result.push({
match self.get_pyty_obj_type(py, constr, unifier, defs, primitives)? {
match self.get_pyty_obj_type(py, constr.as_borrowed(), unifier, defs, primitives)? {
Ok((ty, _)) => {
if unifier.is_concrete(ty, &[]) {
ty
@ -461,22 +468,27 @@ impl InnerResolver {
{
let origin = self.helper.origin_ty_fn.call1(py, (pyty,))?;
let args = self.helper.args_ty_fn.call1(py, (pyty,))?;
let args: &PyTuple = args.downcast(py)?;
let origin_ty =
match self.get_pyty_obj_type(py, origin.as_ref(py), unifier, defs, primitives)? {
Ok((ty, false)) => ty,
Ok((_, true)) => {
return Ok(Err("instantiated type does not take type parameters".into()))
}
Err(err) => return Ok(Err(err)),
};
let args = args.downcast_bound::<PyTuple>(py)?;
let origin_ty = match self.get_pyty_obj_type(
py,
origin.bind_borrowed(py),
unifier,
defs,
primitives,
)? {
Ok((ty, false)) => ty,
Ok((_, true)) => {
return Ok(Err("instantiated type does not take type parameters".into()))
}
Err(err) => return Ok(Err(err)),
};
match &*unifier.get_ty(origin_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(
py,
args.get_item(0)?,
args.get_item(0)?.as_borrowed(),
unifier,
defs,
primitives,
@ -522,9 +534,15 @@ impl InnerResolver {
// npt.NDArray[T] == np.ndarray[Any, np.dtype[T]]
let ndarray_dtype_pyty =
self.helper.args_ty_fn.call1(py, (args.get_item(1)?,))?;
let dtype = ndarray_dtype_pyty.downcast::<PyTuple>(py)?.get_item(0)?;
let dtype = ndarray_dtype_pyty.downcast_bound::<PyTuple>(py)?.get_item(0)?;
let ty = match self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)? {
let ty = match self.get_pyty_obj_type(
py,
dtype.as_borrowed(),
unifier,
defs,
primitives,
)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err)),
};
@ -540,7 +558,7 @@ impl InnerResolver {
TypeEnum::TTuple { .. } => {
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(py, x, unifier, defs, primitives))
.map(|x| self.get_pyty_obj_type(py, x.as_borrowed(), unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
@ -573,7 +591,7 @@ impl InnerResolver {
}
let args = match args
.iter()
.map(|x| self.get_pyty_obj_type(py, x, unifier, defs, primitives))
.map(|x| self.get_pyty_obj_type(py, x.as_borrowed(), unifier, defs, primitives))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.collect::<Result<Vec<_>, _>>() {
@ -600,7 +618,7 @@ impl InnerResolver {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(
py,
args.get_item(0)?,
args.get_item(0)?.as_borrowed(),
unifier,
defs,
primitives,
@ -631,8 +649,7 @@ impl InnerResolver {
false,
)))
} else {
let str_fn =
pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap();
let str_fn = PyModule::import_bound(py, "builtins").unwrap().getattr("repr").unwrap();
let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap();
Ok(Err(format!("{str_repr} is not registered with NAC3 (@nac3 decorator missing?)")))
}
@ -641,7 +658,7 @@ impl InnerResolver {
pub fn get_obj_type(
&self,
py: Python,
obj: &PyAny,
obj: Borrowed<PyAny>,
unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
@ -688,7 +705,7 @@ impl InnerResolver {
{
obj
} else {
ty.as_ref(py)
ty.bind_borrowed(py)
}
},
unifier,
@ -776,7 +793,8 @@ impl InnerResolver {
Ok(Ok(extracted_ty))
} else {
let dtype = obj.getattr("dtype")?.getattr("type")?;
let dtype_ty = self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)?;
let dtype_ty =
self.get_pyty_obj_type(py, dtype.as_borrowed(), unifier, defs, primitives)?;
match dtype_ty {
Ok((t, _)) => match unifier.unify(ty, t) {
Ok(()) => {
@ -795,10 +813,12 @@ impl InnerResolver {
}
}
(TypeEnum::TTuple { .. }, false) => {
let elements: &PyTuple = obj.downcast()?;
let elements = obj.downcast::<PyTuple>()?;
let types: Result<Result<Vec<_>, _>, _> = elements
.iter()
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))
.map(|elem| {
self.get_obj_type(py, elem.as_borrowed(), unifier, defs, primitives)
})
.collect();
let types = types?;
Ok(types.map(|types| {
@ -834,7 +854,13 @@ impl InnerResolver {
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap()));
}
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
let ty = match self.get_obj_type(
py,
field_data.as_borrowed(),
unifier,
defs,
primitives,
)? {
Ok(t) => t,
Err(e) => {
return Ok(Err(format!(
@ -869,15 +895,20 @@ impl InnerResolver {
Ok(d) => d,
Err(e) => return Ok(Err(format!("{e}"))),
};
let ty =
match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
Ok(t) => t,
Err(e) => {
return Ok(Err(format!(
"error when getting type of field `{name}` ({e})"
)))
}
};
let ty = match self.get_obj_type(
py,
field_data.as_borrowed(),
unifier,
defs,
primitives,
)? {
Ok(t) => t,
Err(e) => {
return Ok(Err(format!(
"error when getting type of field `{name}` ({e})"
)))
}
};
let field_ty = unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0);
if let Err(e) = unifier.unify(ty, field_ty) {
// field type mismatch
@ -909,32 +940,32 @@ impl InnerResolver {
// check integer bounds
if unifier.unioned(extracted_ty, primitives.int32) {
obj.extract::<i32>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of int32"))),
|_| Ok(Err(format!("{} is not in the range of int32", obj.as_unbound()))),
|_| Ok(Ok(extracted_ty)),
)
} else if unifier.unioned(extracted_ty, primitives.int64) {
obj.extract::<i64>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of int64"))),
|_| Ok(Err(format!("{} is not in the range of int64", obj.as_unbound()))),
|_| Ok(Ok(extracted_ty)),
)
} else if unifier.unioned(extracted_ty, primitives.uint32) {
obj.extract::<u32>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of uint32"))),
|_| Ok(Err(format!("{} is not in the range of uint32", obj.as_unbound()))),
|_| Ok(Ok(extracted_ty)),
)
} else if unifier.unioned(extracted_ty, primitives.uint64) {
obj.extract::<u64>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of uint64"))),
|_| Ok(Err(format!("{} is not in the range of uint64", obj.as_unbound()))),
|_| Ok(Ok(extracted_ty)),
)
} else if unifier.unioned(extracted_ty, primitives.bool) {
obj.extract::<bool>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of bool"))),
|_| Ok(Err(format!("{} is not in the range of bool", obj.as_unbound()))),
|_| Ok(Ok(extracted_ty)),
)
} else if unifier.unioned(extracted_ty, primitives.float) {
obj.extract::<f64>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of float64"))),
|_| Ok(Err(format!("{} is not in the range of float64", obj.as_unbound()))),
|_| Ok(Ok(extracted_ty)),
)
} else {
@ -947,7 +978,7 @@ impl InnerResolver {
pub fn get_obj_value<'ctx>(
&self,
py: Python,
obj: &PyAny,
obj: Borrowed<PyAny>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
expected_ty: Type,
@ -1017,15 +1048,19 @@ impl InnerResolver {
});
return Ok(Some(global.as_pointer_value().into()));
}
self.global_value_ids.write().insert(id, obj.into());
self.global_value_ids.write().insert(id, obj.as_unbound().clone());
}
let arr: Result<Option<Vec<_>>, _> = (0..len)
.map(|i| {
obj.get_item(i).and_then(|elem| {
self.get_obj_value(py, elem, ctx, generator, elem_ty).map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})
self.get_obj_value(py, elem.as_borrowed(), ctx, generator, elem_ty).map_err(
|e| {
super::CompileError::new_err(format!(
"Error getting element {i}: {e}"
))
},
)
})
})
.collect();
@ -1100,7 +1135,7 @@ impl InnerResolver {
});
return Ok(Some(global.as_pointer_value().into()));
}
self.global_value_ids.write().insert(id, obj.into());
self.global_value_ids.write().insert(id, obj.as_unbound().clone());
}
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims)
@ -1118,15 +1153,23 @@ impl InnerResolver {
};
// Obtain the shape of the ndarray
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
let shape_tuple = obj.getattr("shape")?;
let shape_tuple = shape_tuple.downcast::<PyTuple>()?;
assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
.iter()
.enumerate()
.map(|(i, elem)| {
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err(
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")),
self.get_obj_value(
py,
elem.as_borrowed(),
ctx,
generator,
ctx.primitives.usize(),
)
.map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})
})
.collect();
let shape_values = shape_values?.unwrap();
@ -1147,9 +1190,12 @@ impl InnerResolver {
let data: Result<Option<Vec<_>>, _> = (0..sz)
.map(|i| {
obj.getattr("flat")?.get_item(i).and_then(|elem| {
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})
self.get_obj_value(py, elem.as_borrowed(), ctx, generator, ndarray_dtype)
.map_err(|e| {
super::CompileError::new_err(format!(
"Error getting element {i}: {e}"
))
})
})
})
.collect();
@ -1214,14 +1260,14 @@ impl InnerResolver {
};
let tup_tys = ty.iter();
let elements: &PyTuple = obj.downcast()?;
let elements = obj.downcast::<PyTuple>()?;
assert_eq!(elements.len(), tup_tys.len());
let val: Result<Option<Vec<_>>, _> = elements
.iter()
.enumerate()
.zip(tup_tys)
.map(|((i, elem), ty)| {
self.get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| {
self.get_obj_value(py, elem.as_borrowed(), ctx, generator, *ty).map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})
})
@ -1250,7 +1296,7 @@ impl InnerResolver {
match self
.get_obj_value(
py,
obj.getattr("_nac3_option").unwrap(),
obj.getattr("_nac3_option").unwrap().as_borrowed(),
ctx,
generator,
option_val_ty,
@ -1274,7 +1320,7 @@ impl InnerResolver {
});
return Ok(Some(global.as_pointer_value().into()));
}
self.global_value_ids.write().insert(id, obj.into());
self.global_value_ids.write().insert(id, obj.as_unbound().clone());
}
let global = ctx.module.add_global(
v.get_type(),
@ -1310,7 +1356,7 @@ impl InnerResolver {
});
return Ok(Some(global.as_pointer_value().into()));
}
self.global_value_ids.write().insert(id, obj.into());
self.global_value_ids.write().insert(id, obj.as_unbound().clone());
}
// should be classes
let definition =
@ -1322,7 +1368,7 @@ impl InnerResolver {
.map(|(name, ty, _)| {
self.get_obj_value(
py,
obj.getattr(name.to_string().as_str())?,
obj.getattr(name.to_string().as_str())?.as_borrowed(),
ctx,
generator,
*ty,
@ -1349,7 +1395,7 @@ impl InnerResolver {
fn get_default_param_obj_value(
&self,
py: Python,
obj: &PyAny,
obj: Borrowed<PyAny>,
) -> PyResult<Result<SymbolValue, String>> {
let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
let ty_id: u64 =
@ -1376,16 +1422,21 @@ impl InnerResolver {
let val: f64 = obj.extract()?;
Ok(SymbolValue::Double(val))
} else if ty_id == self.primitive_ids.tuple {
let elements: &PyTuple = obj.downcast()?;
let elements: Result<Result<Vec<_>, String>, _> =
elements.iter().map(|elem| self.get_default_param_obj_value(py, elem)).collect();
let elements = obj.downcast::<PyTuple>()?;
let elements: Result<Result<Vec<_>, String>, _> = elements
.iter()
.map(|elem| self.get_default_param_obj_value(py, elem.as_borrowed()))
.collect();
elements?.map(SymbolValue::Tuple)
} else if ty_id == self.primitive_ids.option {
if id == self.primitive_ids.none {
Ok(SymbolValue::OptionNone)
} else {
self.get_default_param_obj_value(py, obj.getattr("_nac3_option").unwrap())?
.map(|v| SymbolValue::OptionSome(Box::new(v)))
self.get_default_param_obj_value(
py,
obj.getattr("_nac3_option").unwrap().as_borrowed(),
)?
.map(|v| SymbolValue::OptionSome(Box::new(v)))
}
} else {
Err("only primitives values, option and tuple can be default parameter value".into())
@ -1400,13 +1451,14 @@ impl SymbolResolver for Resolver {
};
Python::with_gil(|py| -> PyResult<Option<SymbolValue>> {
let obj: &PyAny = self.0.module.extract(py)?;
let members: &PyDict = obj.getattr("__dict__").unwrap().downcast().unwrap();
let obj = self.0.module.downcast_bound::<PyAny>(py)?;
let members = obj.getattr("__dict__").unwrap();
let members = members.downcast::<PyDict>().unwrap();
let mut sym_value = None;
for (key, val) in members {
let key: &str = key.extract()?;
if key == id.to_string() {
if let Ok(Ok(v)) = self.0.get_default_param_obj_value(py, val) {
if let Ok(Ok(v)) = self.0.get_default_param_obj_value(py, val.as_borrowed()) {
sym_value = Some(v);
}
break;
@ -1440,13 +1492,20 @@ impl SymbolResolver for Resolver {
Ok(t)
} else {
Python::with_gil(|py| -> PyResult<Result<Type, String>> {
let obj: &PyAny = self.0.module.extract(py)?;
let obj = self.0.module.downcast_bound::<PyAny>(py)?;
let mut sym_ty = Err(format!("cannot find symbol `{str}`"));
let members: &PyDict = obj.getattr("__dict__").unwrap().downcast().unwrap();
let members = obj.getattr("__dict__").unwrap();
let members = members.downcast::<PyDict>().unwrap();
for (key, val) in members {
let key: &str = key.extract()?;
if key == str.to_string() {
sym_ty = self.0.get_obj_type(py, val, unifier, defs, primitives)?;
sym_ty = self.0.get_obj_type(
py,
val.as_borrowed(),
unifier,
defs,
primitives,
)?;
break;
}
}
@ -1474,13 +1533,15 @@ impl SymbolResolver for Resolver {
}
.or_else(|| {
Python::with_gil(|py| -> PyResult<Option<(u64, PyObject)>> {
let obj: &PyAny = self.0.module.extract(py)?;
let obj = self.0.module.downcast_bound::<PyAny>(py)?;
let mut sym_value: Option<(u64, PyObject)> = None;
let members: &PyDict = obj.getattr("__dict__").unwrap().downcast().unwrap();
let members = obj.getattr("__dict__").unwrap();
let members = members.downcast::<PyDict>().unwrap();
for (key, val) in members {
let key: &str = key.extract()?;
if key == id.to_string() {
let id = self.0.helper.id_fn.call1(py, (val,))?.extract(py)?;
let id =
self.0.helper.id_fn.call1(py, (val.as_borrowed(),))?.extract(py)?;
sym_value = Some((id, val.extract()?));
break;
}
@ -1549,10 +1610,16 @@ impl SymbolResolver for Resolver {
let store = self.0.deferred_eval_store.store.read();
Python::with_gil(|py| -> PyResult<Result<(), String>> {
for (variables, constraints, name) in store.iter() {
let constraints: &PyAny = constraints.as_ref(py);
let constraints = constraints.bind(py);
for (i, var) in variables.iter().enumerate() {
if let Ok(constr) = constraints.get_item(i) {
match self.0.get_pyty_obj_type(py, constr, unifier, defs, primitives)? {
match self.0.get_pyty_obj_type(
py,
constr.as_borrowed(),
unifier,
defs,
primitives,
)? {
Ok((ty, _)) => {
if !unifier.is_concrete(ty, &[]) {
return Ok(Err(format!(