diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 2b236102..0bb8893c 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -10,13 +10,14 @@ use itertools::Itertools; use parking_lot::RwLock; use pyo3::{ types::{PyDict, PyTuple}, - PyAny, PyObject, PyResult, Python, + PyAny, PyErr, PyObject, PyResult, Python, }; use super::PrimitivePythonId; use nac3core::{ codegen::{ types::{NDArrayType, ProxyType}, + values::make_contiguous_strides, CodeGenContext, CodeGenerator, }, inkwell::{ @@ -28,7 +29,7 @@ use nac3core::{ nac3parser::ast::{self, StrRef}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, toplevel::{ - helper::{extract_ndims, PrimDef}, + helper::PrimDef, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, TopLevelDef, }, @@ -1087,20 +1088,17 @@ impl InnerResolver { let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_usize = generator.get_size_type(ctx.ctx); - let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); - let ndarray_llvm_ty = NDArrayType::new( - generator, - ctx.ctx, - ndarray_dtype_llvm_ty, - Some(extract_ndims(&ctx.unifier, ndarray_ndims)), - ); + let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty); + let dtype = llvm_ndarray.element_type(); { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global( - ndarray_llvm_ty.as_base_type().get_element_type().into_struct_type(), + llvm_ndarray.as_base_type().get_element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ) @@ -1120,30 +1118,43 @@ impl InnerResolver { } else { todo!("Unpacking literal of more than one element unimplemented") }; - let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else { + let Ok(ndims) = u64::try_from(ndarray_ndims) else { unreachable!("Expected u64 value for ndarray_ndims") }; // Obtain the shape of the ndarray let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; - assert_eq!(shape_tuple.len(), ndarray_ndims as usize); - let shape_values: Result>, _> = shape_tuple + assert_eq!(shape_tuple.len(), ndims as usize); + + // The Rust type inferencer cannot figure this out + let shape_values = shape_tuple .iter() .enumerate() .map(|(i, elem)| { - self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err( - |e| super::CompileError::new_err(format!("Error getting element {i}: {e}")), - ) + let value = self + .get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting element {i}: {e}")) + })? + .unwrap(); + let value = value.into_int_value(); + Ok(value) }) - .collect(); - let shape_values = shape_values?.unwrap(); - let shape_values = llvm_usize.const_array( - &shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(), - ); + .collect::, PyErr>>()?; + + // Also use this opportunity to get the constant values of `shape_values` for calculating strides. + let shape_u64s = shape_values + .iter() + .map(|dim| { + assert!(dim.is_const()); + dim.get_zero_extended_constant().unwrap() + }) + .collect_vec(); + let shape_values = llvm_usize.const_array(&shape_values); // create a global for ndarray.shape and initialize it using the shape let shape_global = ctx.module.add_global( - llvm_usize.array_type(ndarray_ndims as u32), + llvm_usize.array_type(ndims as u32), Some(AddressSpace::default()), &(id_str.clone() + ".shape"), ); @@ -1151,17 +1162,25 @@ impl InnerResolver { // Obtain the (flattened) elements of the ndarray let sz: usize = obj.getattr("size")?.extract()?; - let data: Result>, _> = (0..sz) + let data: Vec<_> = (0..sz) .map(|i| { obj.getattr("flat")?.get_item(i).and_then(|elem| { - self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| { - super::CompileError::new_err(format!("Error getting element {i}: {e}")) - }) + let value = self + .get_obj_value(py, elem, ctx, generator, ndarray_dtype) + .map_err(|e| { + super::CompileError::new_err(format!( + "Error getting element {i}: {e}" + )) + })? + .unwrap(); + + assert_eq!(value.get_type(), dtype); + Ok(value) }) }) - .collect(); - let data = data?.unwrap().into_iter(); - let data = match ndarray_dtype_llvm_ty { + .try_collect()?; + let data = data.into_iter(); + let data = match dtype { BasicTypeEnum::ArrayType(ty) => { ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec()) } @@ -1186,38 +1205,68 @@ impl InnerResolver { }; // create a global for ndarray.data and initialize it using the elements + // + // NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`. + // We will have to cast it to an `u8*` later. let data_global = ctx.module.add_global( - ndarray_dtype_llvm_ty.array_type(sz as u32), + dtype.array_type(sz as u32), Some(AddressSpace::default()), &(id_str.clone() + ".data"), ); data_global.set_initializer(&data); + // Get the constant itemsize. + let itemsize = dtype.size_of().unwrap(); + let itemsize = itemsize.get_zero_extended_constant().unwrap(); + + // Create the strides needed for ndarray.strides + let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s); + let strides = + strides.into_iter().map(|stride| llvm_usize.const_int(stride, false)).collect_vec(); + let strides = llvm_usize.const_array(&strides); + + // create a global for ndarray.strides and initialize it + let strides_global = ctx.module.add_global( + llvm_i8.array_type(ndims as u32), + Some(AddressSpace::default()), + &format!("${id_str}.strides"), + ); + strides_global.set_initializer(&strides); + // create a global for the ndarray object and initialize it - let value = ndarray_llvm_ty + + // NOTE: data_global is an array of dtype, we want a `u8*`. + let ndarray_data = data_global.as_pointer_value(); + let ndarray_data = ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap(); + + let ndarray_itemsize = llvm_usize.const_int(itemsize, false); + + let ndarray_ndims = llvm_usize.const_int(ndims, false); + + let ndarray_shape = shape_global.as_pointer_value(); + + let ndarray_strides = strides_global.as_pointer_value(); + + let ndarray = llvm_ndarray .as_base_type() .get_element_type() .into_struct_type() .const_named_struct(&[ - llvm_usize.const_int(ndarray_ndims, false).into(), - shape_global - .as_pointer_value() - .const_cast(llvm_usize.ptr_type(AddressSpace::default())) - .into(), - data_global - .as_pointer_value() - .const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default())) - .into(), + ndarray_itemsize.into(), + ndarray_ndims.into(), + ndarray_shape.into(), + ndarray_strides.into(), + ndarray_data.into(), ]); - let ndarray = ctx.module.add_global( - ndarray_llvm_ty.as_base_type().get_element_type().into_struct_type(), + let ndarray_global = ctx.module.add_global( + llvm_ndarray.as_base_type().get_element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ); - ndarray.set_initializer(&value); + ndarray_global.set_initializer(&ndarray); - Ok(Some(ndarray.as_pointer_value().into())) + Ok(Some(ndarray_global.as_pointer_value().into())) } else if ty_id == self.primitive_ids.tuple { let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 65e3313f..3fda45d4 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -881,3 +881,18 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, for NDArrayDataProxy<'ctx, '_> { } + +/// A version of [`call_nac3_ndarray_set_strides_by_shape`] in Rust. +/// +/// This function is used generating strides for globally defined contiguous ndarrays. +#[must_use] +pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec { + let mut strides = Vec::with_capacity(ndims as usize); + let mut stride_product = 1u64; + for i in 0..ndims { + let axis = ndims - i - 1; + strides[axis as usize] = stride_product * itemsize; + stride_product *= shape[axis as usize]; + } + strides +}