From c9c9dae91b2fc4c9b76fd9ab34ce01125d9b7bf1 Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 22 Aug 2024 16:19:09 +0800 Subject: [PATCH] artiq: reimplement get_obj_value to use ndarray with strides --- nac3artiq/src/symbol_resolver.rs | 170 +++++++++++++-------- nac3core/src/codegen/object/ndarray/mod.rs | 15 ++ 2 files changed, 118 insertions(+), 67 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 9470ee7..21b1ac1 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1,14 +1,15 @@ use crate::PrimitivePythonId; use inkwell::{ module::Linkage, - types::{BasicType, BasicTypeEnum}, - values::BasicValueEnum, + types::BasicType, + values::{BasicValue, BasicValueEnum}, AddressSpace, }; use itertools::Itertools; use nac3core::{ codegen::{ - classes::{NDArrayType, ProxyType}, + model::*, + object::ndarray::{make_contiguous_strides, NDArray}, CodeGenContext, CodeGenerator, }, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, @@ -26,7 +27,7 @@ use nac3parser::ast::{self, StrRef}; use parking_lot::{Mutex, RwLock}; use pyo3::{ types::{PyDict, PyTuple}, - PyAny, PyObject, PyResult, Python, + PyAny, PyErr, PyObject, PyResult, Python, }; use std::{ collections::{HashMap, HashSet}, @@ -1086,15 +1087,12 @@ impl InnerResolver { let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); - let llvm_usize = generator.get_size_type(ctx.ctx); - let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); - let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty); - + let dtype = Any(ctx.get_llvm_type(generator, ndarray_dtype)); { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global( - ndarray_llvm_ty.as_underlying_type(), + Struct(NDArray).get_type(generator, ctx.ctx), Some(AddressSpace::default()), &id_str, ) @@ -1114,100 +1112,138 @@ impl InnerResolver { } else { todo!("Unpacking literal of more than one element unimplemented") }; - let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else { + let Ok(ndims) = u64::try_from(ndarray_ndims) else { unreachable!("Expected u64 value for ndarray_ndims") }; // Obtain the shape of the ndarray let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; - assert_eq!(shape_tuple.len(), ndarray_ndims as usize); - let shape_values: Result>, _> = shape_tuple + assert_eq!(shape_tuple.len(), ndims as usize); + + // The Rust type inferencer cannot figure this out + let shape_values: Result>>, PyErr> = shape_tuple .iter() .enumerate() .map(|(i, elem)| { - self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err( - |e| super::CompileError::new_err(format!("Error getting element {i}: {e}")), - ) + let value = self + .get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting element {i}: {e}")) + })? + .unwrap(); + let value = Int(SizeT).check_value(generator, ctx.ctx, value).unwrap(); + Ok(value) }) .collect(); - let shape_values = shape_values?.unwrap(); - let shape_values = llvm_usize.const_array( - &shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(), - ); + let shape_values = shape_values?; + + // Also use this opportunity to get the constant values of `shape_values` for calculating strides. + let shape_u64s = shape_values + .iter() + .map(|dim| { + assert!(dim.value.is_const()); + dim.value.get_zero_extended_constant().unwrap() + }) + .collect_vec(); + let shape_values = Int(SizeT).const_array(generator, ctx.ctx, &shape_values); // create a global for ndarray.shape and initialize it using the shape let shape_global = ctx.module.add_global( - llvm_usize.array_type(ndarray_ndims as u32), + Array { len: AnyLen(ndims as u32), item: Int(SizeT) }.get_type(generator, ctx.ctx), Some(AddressSpace::default()), &(id_str.clone() + ".shape"), ); - shape_global.set_initializer(&shape_values); + shape_global.set_initializer(&shape_values.value); // Obtain the (flattened) elements of the ndarray let sz: usize = obj.getattr("size")?.extract()?; - let data: Result>, _> = (0..sz) + let data_values: Vec> = (0..sz) .map(|i| { obj.getattr("flat")?.get_item(i).and_then(|elem| { - self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| { - super::CompileError::new_err(format!("Error getting element {i}: {e}")) - }) + let value = self + .get_obj_value(py, elem, ctx, generator, ndarray_dtype) + .map_err(|e| { + super::CompileError::new_err(format!( + "Error getting element {i}: {e}" + )) + })? + .unwrap(); + + let value = dtype.check_value(generator, ctx.ctx, value).unwrap(); + Ok(value) }) }) - .collect(); - let data = data?.unwrap().into_iter(); - let data = match ndarray_dtype_llvm_ty { - BasicTypeEnum::ArrayType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec()) - } - - BasicTypeEnum::FloatType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec()) - } - - BasicTypeEnum::IntType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec()) - } - - BasicTypeEnum::PointerType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec()) - } - - BasicTypeEnum::StructType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec()) - } - - BasicTypeEnum::VectorType(_) => unreachable!(), - }; + .try_collect()?; + let data = dtype.const_array(generator, ctx.ctx, &data_values); // create a global for ndarray.data and initialize it using the elements + // + // NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`. + // We will have to cast it to an `u8*` later. let data_global = ctx.module.add_global( - ndarray_dtype_llvm_ty.array_type(sz as u32), + Array { len: AnyLen(sz as u32), item: dtype }.get_type(generator, ctx.ctx), Some(AddressSpace::default()), &(id_str.clone() + ".data"), ); - data_global.set_initializer(&data); + data_global.set_initializer(&data.value); + + // Get the constant itemsize. + let itemsize = dtype.get_type(generator, ctx.ctx).size_of().unwrap(); + let itemsize = itemsize.get_zero_extended_constant().unwrap(); + + // Create the strides needed for ndarray.strides + let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s); + let strides = strides + .into_iter() + .map(|stride| Int(SizeT).const_int(generator, ctx.ctx, stride)) + .collect_vec(); + let strides = Int(SizeT).const_array(generator, ctx.ctx, &strides); + + // create a global for ndarray.strides and initialize it + let strides_global = ctx.module.add_global( + Array { len: AnyLen(ndims as u32), item: Int(Byte) }.get_type(generator, ctx.ctx), + Some(AddressSpace::default()), + &(id_str.clone() + ".strides"), + ); + strides_global.set_initializer(&strides.value); // create a global for the ndarray object and initialize it - let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[ - llvm_usize.const_int(ndarray_ndims, false).into(), - shape_global - .as_pointer_value() - .const_cast(llvm_usize.ptr_type(AddressSpace::default())) - .into(), - data_global - .as_pointer_value() - .const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default())) - .into(), - ]); + // We are also doing [`Model::check_value`] instead of [`Model::believe_value`] to catch bugs. - let ndarray = ctx.module.add_global( - ndarray_llvm_ty.as_underlying_type(), + // NOTE: data_global is an array of dtype, we want a `u8*`. + let ndarray_data = Ptr(dtype).check_value(generator, ctx.ctx, data_global).unwrap(); + let ndarray_data = Ptr(Int(Byte)).pointer_cast(generator, ctx, ndarray_data.value); + + let ndarray_itemsize = Int(SizeT).const_int(generator, ctx.ctx, itemsize); + + let ndarray_ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims); + + let ndarray_shape = + Ptr(Int(SizeT)).check_value(generator, ctx.ctx, shape_global).unwrap(); + + let ndarray_strides = + Ptr(Int(SizeT)).check_value(generator, ctx.ctx, strides_global).unwrap(); + + let ndarray = Struct(NDArray).const_struct( + generator, + ctx.ctx, + &[ + ndarray_data.value.as_basic_value_enum(), + ndarray_itemsize.value.as_basic_value_enum(), + ndarray_ndims.value.as_basic_value_enum(), + ndarray_shape.value.as_basic_value_enum(), + ndarray_strides.value.as_basic_value_enum(), + ], + ); + + let ndarray_global = ctx.module.add_global( + Struct(NDArray).get_type(generator, ctx.ctx), Some(AddressSpace::default()), &id_str, ); - ndarray.set_initializer(&value); + ndarray_global.set_initializer(&ndarray.value); - Ok(Some(ndarray.as_pointer_value().into())) + Ok(Some(ndarray_global.as_pointer_value().into())) } else if ty_id == self.primitive_ids.tuple { let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index ce8196d..91a0d7b 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -651,3 +651,18 @@ impl<'ctx> NDArrayOut<'ctx> { } } } + +/// A version of [`call_nac3_ndarray_set_strides_by_shape`] in Rust. +/// +/// This function is used generating strides for globally defined contiguous ndarrays. +#[must_use] +pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec { + let mut strides = Vec::with_capacity(ndims as usize); + let mut stride_product = 1u64; + for i in 0..ndims { + let axis = ndims - i - 1; + strides[axis as usize] = stride_product * itemsize; + stride_product *= shape[axis as usize]; + } + strides +}