artiq: reimplement get_obj_value to use ndarray with strides

This commit is contained in:
lyken 2024-08-22 16:19:09 +08:00
parent bb512b8f57
commit 1f0463463f
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
2 changed files with 118 additions and 67 deletions

View File

@ -1,14 +1,15 @@
use crate::PrimitivePythonId; use crate::PrimitivePythonId;
use inkwell::{ use inkwell::{
module::Linkage, module::Linkage,
types::{BasicType, BasicTypeEnum}, types::BasicType,
values::BasicValueEnum, values::{BasicValue, BasicValueEnum},
AddressSpace, AddressSpace,
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3core::{ use nac3core::{
codegen::{ codegen::{
classes::{NDArrayType, ProxyType}, model::*,
object::ndarray::{make_contiguous_strides, NDArray},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
@ -26,7 +27,7 @@ use nac3parser::ast::{self, StrRef};
use parking_lot::RwLock; use parking_lot::RwLock;
use pyo3::{ use pyo3::{
types::{PyDict, PyTuple}, types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python, PyAny, PyErr, PyObject, PyResult, Python,
}; };
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
@ -1085,15 +1086,12 @@ impl InnerResolver {
let (ndarray_dtype, ndarray_ndims) = let (ndarray_dtype, ndarray_ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let llvm_usize = generator.get_size_type(ctx.ctx); let dtype = Any(ctx.get_llvm_type(generator, ndarray_dtype));
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
{ {
if self.global_value_ids.read().contains_key(&id) { if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global( ctx.module.add_global(
ndarray_llvm_ty.as_underlying_type(), Struct(NDArray).llvm_type(generator, ctx.ctx),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&id_str, &id_str,
) )
@ -1113,100 +1111,138 @@ impl InnerResolver {
} else { } else {
todo!("Unpacking literal of more than one element unimplemented") todo!("Unpacking literal of more than one element unimplemented")
}; };
let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else { let Ok(ndims) = u64::try_from(ndarray_ndims) else {
unreachable!("Expected u64 value for ndarray_ndims") unreachable!("Expected u64 value for ndarray_ndims")
}; };
// Obtain the shape of the ndarray // Obtain the shape of the ndarray
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
assert_eq!(shape_tuple.len(), ndarray_ndims as usize); assert_eq!(shape_tuple.len(), ndims as usize);
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
// The Rust type inferencer cannot figure this out
let shape_values: Result<Vec<Instance<'ctx, Int<SizeT>>>, PyErr> = shape_tuple
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, elem)| { .map(|(i, elem)| {
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err( let value = self
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")), .get_obj_value(py, elem, ctx, generator, ctx.primitives.usize())
) .map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})?
.unwrap();
let value = Int(SizeT).check_value(generator, ctx.ctx, value).unwrap();
Ok(value)
}) })
.collect(); .collect();
let shape_values = shape_values?.unwrap(); let shape_values = shape_values?;
let shape_values = llvm_usize.const_array(
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(), // Also use this opportunity to get the constant values of `shape_values` for calculating strides.
); let shape_u64s = shape_values
.iter()
.map(|dim| {
assert!(dim.value.is_const());
dim.value.get_zero_extended_constant().unwrap()
})
.collect_vec();
let shape_values = Int(SizeT).const_array(generator, ctx.ctx, &shape_values);
// create a global for ndarray.shape and initialize it using the shape // create a global for ndarray.shape and initialize it using the shape
let shape_global = ctx.module.add_global( let shape_global = ctx.module.add_global(
llvm_usize.array_type(ndarray_ndims as u32), Array { len: AnyLen(ndims as u32), item: Int(SizeT) }.llvm_type(generator, ctx.ctx),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&(id_str.clone() + ".shape"), &(id_str.clone() + ".shape"),
); );
shape_global.set_initializer(&shape_values); shape_global.set_initializer(&shape_values.value);
// Obtain the (flattened) elements of the ndarray // Obtain the (flattened) elements of the ndarray
let sz: usize = obj.getattr("size")?.extract()?; let sz: usize = obj.getattr("size")?.extract()?;
let data: Result<Option<Vec<_>>, _> = (0..sz) let data_values: Vec<Instance<'ctx, Any>> = (0..sz)
.map(|i| { .map(|i| {
obj.getattr("flat")?.get_item(i).and_then(|elem| { obj.getattr("flat")?.get_item(i).and_then(|elem| {
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| { let value = self
super::CompileError::new_err(format!("Error getting element {i}: {e}")) .get_obj_value(py, elem, ctx, generator, ndarray_dtype)
}) .map_err(|e| {
super::CompileError::new_err(format!(
"Error getting element {i}: {e}"
))
})?
.unwrap();
let value = dtype.check_value(generator, ctx.ctx, value).unwrap();
Ok(value)
}) })
}) })
.collect(); .try_collect()?;
let data = data?.unwrap().into_iter(); let data = dtype.const_array(generator, ctx.ctx, &data_values);
let data = match ndarray_dtype_llvm_ty {
BasicTypeEnum::ArrayType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
}
BasicTypeEnum::FloatType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec())
}
BasicTypeEnum::IntType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec())
}
BasicTypeEnum::PointerType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec())
}
BasicTypeEnum::StructType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec())
}
BasicTypeEnum::VectorType(_) => unreachable!(),
};
// create a global for ndarray.data and initialize it using the elements // create a global for ndarray.data and initialize it using the elements
//
// NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`.
// We will have to cast it to an `u8*` later.
let data_global = ctx.module.add_global( let data_global = ctx.module.add_global(
ndarray_dtype_llvm_ty.array_type(sz as u32), Array { len: AnyLen(sz as u32), item: dtype }.llvm_type(generator, ctx.ctx),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&(id_str.clone() + ".data"), &(id_str.clone() + ".data"),
); );
data_global.set_initializer(&data); data_global.set_initializer(&data.value);
// Get the constant itemsize.
let itemsize = dtype.llvm_type(generator, ctx.ctx).size_of().unwrap();
let itemsize = itemsize.get_zero_extended_constant().unwrap();
// Create the strides needed for ndarray.strides
let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s);
let strides = strides
.into_iter()
.map(|stride| Int(SizeT).const_int(generator, ctx.ctx, stride, false))
.collect_vec();
let strides = Int(SizeT).const_array(generator, ctx.ctx, &strides);
// create a global for ndarray.strides and initialize it
let strides_global = ctx.module.add_global(
Array { len: AnyLen(ndims as u32), item: Int(Byte) }.llvm_type(generator, ctx.ctx),
Some(AddressSpace::default()),
&(id_str.clone() + ".strides"),
);
strides_global.set_initializer(&strides.value);
// create a global for the ndarray object and initialize it // create a global for the ndarray object and initialize it
let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[ // We are also doing [`Model::check_value`] instead of [`Model::believe_value`] to catch bugs.
llvm_usize.const_int(ndarray_ndims, false).into(),
shape_global
.as_pointer_value()
.const_cast(llvm_usize.ptr_type(AddressSpace::default()))
.into(),
data_global
.as_pointer_value()
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
.into(),
]);
let ndarray = ctx.module.add_global( // NOTE: data_global is an array of dtype, we want a `u8*`.
ndarray_llvm_ty.as_underlying_type(), let ndarray_data = Ptr(dtype).check_value(generator, ctx.ctx, data_global).unwrap();
let ndarray_data = Ptr(Int(Byte)).pointer_cast(generator, ctx, ndarray_data.value);
let ndarray_itemsize = Int(SizeT).const_int(generator, ctx.ctx, itemsize, false);
let ndarray_ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims, false);
let ndarray_shape =
Ptr(Int(SizeT)).check_value(generator, ctx.ctx, shape_global).unwrap();
let ndarray_strides =
Ptr(Int(SizeT)).check_value(generator, ctx.ctx, strides_global).unwrap();
let ndarray = Struct(NDArray).const_struct(
generator,
ctx.ctx,
&[
ndarray_data.value.as_basic_value_enum(),
ndarray_itemsize.value.as_basic_value_enum(),
ndarray_ndims.value.as_basic_value_enum(),
ndarray_shape.value.as_basic_value_enum(),
ndarray_strides.value.as_basic_value_enum(),
],
);
let ndarray_global = ctx.module.add_global(
Struct(NDArray).llvm_type(generator, ctx.ctx),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&id_str, &id_str,
); );
ndarray.set_initializer(&value); ndarray_global.set_initializer(&ndarray.value);
Ok(Some(ndarray.as_pointer_value().into())) Ok(Some(ndarray_global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {

View File

@ -653,3 +653,18 @@ impl<'ctx> NDArrayOut<'ctx> {
} }
} }
} }
/// A version of [`call_nac3_ndarray_set_strides_by_shape`] in Rust.
///
/// This function is used generating strides for globally defined contiguous ndarrays.
#[must_use]
pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec<u64> {
let mut strides = Vec::with_capacity(ndims as usize);
let mut stride_product = 1u64;
for i in 0..ndims {
let axis = ndims - i - 1;
strides[axis as usize] = stride_product * itemsize;
stride_product *= shape[axis as usize];
}
strides
}