forked from M-Labs/nac3
[artiq] Reimplement get_obj_value for strided ndarray
Based on 7ef93472
: artiq: reimplement get_obj_value to use ndarray with
strides
This commit is contained in:
parent
f4c5038b95
commit
8b3429d62a
@ -10,13 +10,14 @@ use itertools::Itertools;
|
||||
use parking_lot::RwLock;
|
||||
use pyo3::{
|
||||
types::{PyDict, PyTuple},
|
||||
PyAny, PyObject, PyResult, Python,
|
||||
PyAny, PyErr, PyObject, PyResult, Python,
|
||||
};
|
||||
|
||||
use super::PrimitivePythonId;
|
||||
use nac3core::{
|
||||
codegen::{
|
||||
types::{ndarray::NDArrayType, ProxyType},
|
||||
values::ndarray::make_contiguous_strides,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
inkwell::{
|
||||
@ -1086,15 +1087,17 @@ impl InnerResolver {
|
||||
};
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
||||
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let ndarray_llvm_ty = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty);
|
||||
let ndarray_dtype_llvm_ty = ndarray_llvm_ty.element_type();
|
||||
let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty);
|
||||
let dtype = llvm_ndarray.element_type();
|
||||
|
||||
{
|
||||
if self.global_value_ids.read().contains_key(&id) {
|
||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(
|
||||
ndarray_llvm_ty.as_base_type().get_element_type().into_struct_type(),
|
||||
llvm_ndarray.as_base_type().get_element_type().into_struct_type(),
|
||||
Some(AddressSpace::default()),
|
||||
&id_str,
|
||||
)
|
||||
@ -1104,28 +1107,41 @@ impl InnerResolver {
|
||||
self.global_value_ids.write().insert(id, obj.into());
|
||||
}
|
||||
|
||||
let ndarray_ndims = ndarray_llvm_ty.ndims().unwrap();
|
||||
let ndims = llvm_ndarray.ndims().unwrap();
|
||||
|
||||
// Obtain the shape of the ndarray
|
||||
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
|
||||
assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
|
||||
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
|
||||
assert_eq!(shape_tuple.len(), ndims as usize);
|
||||
|
||||
// The Rust type inferencer cannot figure this out
|
||||
let shape_values = shape_tuple
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, elem)| {
|
||||
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err(
|
||||
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")),
|
||||
)
|
||||
let value = self
|
||||
.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize())
|
||||
.map_err(|e| {
|
||||
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
|
||||
})?
|
||||
.unwrap();
|
||||
let value = value.into_int_value();
|
||||
Ok(value)
|
||||
})
|
||||
.collect();
|
||||
let shape_values = shape_values?.unwrap();
|
||||
let shape_values = llvm_usize.const_array(
|
||||
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
|
||||
);
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
// Also use this opportunity to get the constant values of `shape_values` for calculating strides.
|
||||
let shape_u64s = shape_values
|
||||
.iter()
|
||||
.map(|dim| {
|
||||
assert!(dim.is_const());
|
||||
dim.get_zero_extended_constant().unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
let shape_values = llvm_usize.const_array(&shape_values);
|
||||
|
||||
// create a global for ndarray.shape and initialize it using the shape
|
||||
let shape_global = ctx.module.add_global(
|
||||
llvm_usize.array_type(ndarray_ndims as u32),
|
||||
llvm_usize.array_type(ndims as u32),
|
||||
Some(AddressSpace::default()),
|
||||
&(id_str.clone() + ".shape"),
|
||||
);
|
||||
@ -1133,17 +1149,25 @@ impl InnerResolver {
|
||||
|
||||
// Obtain the (flattened) elements of the ndarray
|
||||
let sz: usize = obj.getattr("size")?.extract()?;
|
||||
let data: Result<Option<Vec<_>>, _> = (0..sz)
|
||||
let data: Vec<_> = (0..sz)
|
||||
.map(|i| {
|
||||
obj.getattr("flat")?.get_item(i).and_then(|elem| {
|
||||
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| {
|
||||
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
|
||||
})
|
||||
let value = self
|
||||
.get_obj_value(py, elem, ctx, generator, ndarray_dtype)
|
||||
.map_err(|e| {
|
||||
super::CompileError::new_err(format!(
|
||||
"Error getting element {i}: {e}"
|
||||
))
|
||||
})?
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(value.get_type(), dtype);
|
||||
Ok(value)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let data = data?.unwrap().into_iter();
|
||||
let data = match ndarray_dtype_llvm_ty {
|
||||
.try_collect()?;
|
||||
let data = data.into_iter();
|
||||
let data = match dtype {
|
||||
BasicTypeEnum::ArrayType(ty) => {
|
||||
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
|
||||
}
|
||||
@ -1168,38 +1192,68 @@ impl InnerResolver {
|
||||
};
|
||||
|
||||
// create a global for ndarray.data and initialize it using the elements
|
||||
//
|
||||
// NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`.
|
||||
// We will have to cast it to an `u8*` later.
|
||||
let data_global = ctx.module.add_global(
|
||||
ndarray_dtype_llvm_ty.array_type(sz as u32),
|
||||
dtype.array_type(sz as u32),
|
||||
Some(AddressSpace::default()),
|
||||
&(id_str.clone() + ".data"),
|
||||
);
|
||||
data_global.set_initializer(&data);
|
||||
|
||||
// Get the constant itemsize.
|
||||
let itemsize = dtype.size_of().unwrap();
|
||||
let itemsize = itemsize.get_zero_extended_constant().unwrap();
|
||||
|
||||
// Create the strides needed for ndarray.strides
|
||||
let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s);
|
||||
let strides =
|
||||
strides.into_iter().map(|stride| llvm_usize.const_int(stride, false)).collect_vec();
|
||||
let strides = llvm_usize.const_array(&strides);
|
||||
|
||||
// create a global for ndarray.strides and initialize it
|
||||
let strides_global = ctx.module.add_global(
|
||||
llvm_i8.array_type(ndims as u32),
|
||||
Some(AddressSpace::default()),
|
||||
&format!("${id_str}.strides"),
|
||||
);
|
||||
strides_global.set_initializer(&strides);
|
||||
|
||||
// create a global for the ndarray object and initialize it
|
||||
let value = ndarray_llvm_ty
|
||||
|
||||
// NOTE: data_global is an array of dtype, we want a `u8*`.
|
||||
let ndarray_data = data_global.as_pointer_value();
|
||||
let ndarray_data = ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
|
||||
|
||||
let ndarray_itemsize = llvm_usize.const_int(itemsize, false);
|
||||
|
||||
let ndarray_ndims = llvm_usize.const_int(ndims, false);
|
||||
|
||||
let ndarray_shape = shape_global.as_pointer_value();
|
||||
|
||||
let ndarray_strides = strides_global.as_pointer_value();
|
||||
|
||||
let ndarray = llvm_ndarray
|
||||
.as_base_type()
|
||||
.get_element_type()
|
||||
.into_struct_type()
|
||||
.const_named_struct(&[
|
||||
llvm_usize.const_int(ndarray_ndims, false).into(),
|
||||
shape_global
|
||||
.as_pointer_value()
|
||||
.const_cast(llvm_usize.ptr_type(AddressSpace::default()))
|
||||
.into(),
|
||||
data_global
|
||||
.as_pointer_value()
|
||||
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
|
||||
.into(),
|
||||
ndarray_itemsize.into(),
|
||||
ndarray_ndims.into(),
|
||||
ndarray_shape.into(),
|
||||
ndarray_strides.into(),
|
||||
ndarray_data.into(),
|
||||
]);
|
||||
|
||||
let ndarray = ctx.module.add_global(
|
||||
ndarray_llvm_ty.as_base_type().get_element_type().into_struct_type(),
|
||||
let ndarray_global = ctx.module.add_global(
|
||||
llvm_ndarray.as_base_type().get_element_type().into_struct_type(),
|
||||
Some(AddressSpace::default()),
|
||||
&id_str,
|
||||
);
|
||||
ndarray.set_initializer(&value);
|
||||
ndarray_global.set_initializer(&ndarray);
|
||||
|
||||
Ok(Some(ndarray.as_pointer_value().into()))
|
||||
Ok(Some(ndarray_global.as_pointer_value().into()))
|
||||
} else if ty_id == self.primitive_ids.tuple {
|
||||
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
||||
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {
|
||||
|
@ -867,3 +867,18 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
|
||||
for NDArrayDataProxy<'ctx, '_>
|
||||
{
|
||||
}
|
||||
|
||||
/// A version of [`call_nac3_ndarray_set_strides_by_shape`] in Rust.
|
||||
///
|
||||
/// This function is used generating strides for globally defined contiguous ndarrays.
|
||||
#[must_use]
|
||||
pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec<u64> {
|
||||
let mut strides = Vec::with_capacity(ndims as usize);
|
||||
let mut stride_product = 1u64;
|
||||
for i in 0..ndims {
|
||||
let axis = ndims - i - 1;
|
||||
strides[axis as usize] = stride_product * itemsize;
|
||||
stride_product *= shape[axis as usize];
|
||||
}
|
||||
strides
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user