From 0452e6de786bb8dcb69ad99adfc85ac9b9f79e39 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 20 Jun 2024 12:24:26 +0800 Subject: [PATCH] core: Fix codegen for tuple-index into ndarray --- nac3core/src/codegen/expr.rs | 305 +++++++++++++++++++++-------------- 1 file changed, 181 insertions(+), 124 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index d5075485..3c5d0a3d 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -3,8 +3,8 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use crate::{ codegen::{ classes::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue, - RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue, + TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_abi_type, get_llvm_type, @@ -1741,22 +1741,37 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( let ndims = values .iter() - .map(|ndim| match *ndim { - SymbolValue::U64(v) => Ok(v), - SymbolValue::U32(v) => Ok(u64::from(v)), - SymbolValue::I32(v) => u64::try_from(v) - .map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")), - SymbolValue::I64(v) => u64::try_from(v) - .map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")), - _ => unreachable!(), - }) - .collect::, _>>()?; + .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) + .collect::, _>>() + .map_err(|val| { + format!( + "Expected non-negative literal for ndarray.ndims, got {}", + i128::try_from(val).unwrap() + ) + })?; assert!(!ndims.is_empty()); - let ndarray_ndims_ty = ctx - .unifier - .get_fresh_literal(ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(), None); + // The number of dimensions subscripted by the index expression. + // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a + // dimension will remove a dimension. + let subscripted_dims = match &slice.node { + ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| { + if let ExprKind::Slice { .. } = &value_subexpr.node { + acc + } else { + acc + 1 + } + }), + + ExprKind::Slice { .. } => 0, + _ => 1, + }; + + let ndarray_ndims_ty = ctx.unifier.get_fresh_literal( + ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(), + None, + ); let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty)); let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); @@ -1859,123 +1874,165 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( } }; - Ok(Some(match &slice.node { - ExprKind::Tuple { elts, .. } => { - let slices = elts - .iter() - .enumerate() - .map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64)) - .take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some)) - .collect::, _>>()?; - if slices.len() < elts.len() { - return Ok(None); - } - - let slices = slices.into_iter().map(Option::unwrap).collect_vec(); - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into() - } - - ExprKind::Slice { .. } => { - let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { - return Ok(None); - }; - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into() - } - - _ => { - let index = if let Some(index) = generator.gen_expr(ctx, slice)? { - index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() - } else { - return Ok(None); - }; - let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) }; - let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?; - ctx.builder.build_store(index_addr, index).unwrap(); - - if ndims.len() == 1 && ndims[0] == 1 { - // Accessing an element from a 1-dimensional `ndarray` - - return Ok(Some( - v.data() - .get( - ctx, - generator, - &ArraySliceValue::from_ptr_val( - index_addr, - llvm_usize.const_int(1, false), - None, - ), - None, - ) - .into(), - )); - } - - // Accessing an element from a multi-dimensional `ndarray` - - // Create a new array, remove the top dimension from the dimension-size-list, and copy the - // elements over - let subscripted_ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None); - - let num_dims = v.load_ndims(ctx); - ndarray.store_ndims( + let make_indices_arr = |generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>| + -> Result<_, String> { + Ok(if let ExprKind::Tuple { elts, .. } = &slice.node { + let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap()); + let index_addr = generator.gen_array_var_alloc( ctx, - generator, - ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), "").unwrap(), - ); + llvm_int_ty, + llvm_usize.const_int(elts.len() as u64, false), + None, + )?; - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); + for (i, elt) in elts.iter().enumerate() { + let Some(index) = generator.gen_expr(ctx, elt)? else { + return Ok(None); + }; - let ndarray_num_dims = ndarray.load_ndims(ctx); - let v_dims_src_ptr = unsafe { - v.dim_sizes().ptr_offset_unchecked( + let index = index + .to_basic_value_enum(ctx, generator, elt.custom.unwrap())? + .into_int_value(); + let Some(index) = normalize_index(generator, ctx, index, 0)? else { + return Ok(None); + }; + + let store_ptr = unsafe { + index_addr.ptr_offset_unchecked( + ctx, + generator, + &llvm_usize.const_int(i as u64, false), + None, + ) + }; + ctx.builder.build_store(store_ptr, index).unwrap(); + } + + Some(index_addr) + } else if let Some(index) = generator.gen_expr(ctx, slice)? { + let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap()); + let index_addr = generator.gen_array_var_alloc( + ctx, + llvm_int_ty, + llvm_usize.const_int(1u64, false), + None, + )?; + + let index = + index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value(); + let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) }; + + let store_ptr = unsafe { + index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + }; + ctx.builder.build_store(store_ptr, index).unwrap(); + + Some(index_addr) + } else { + None + }) + }; + + Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 { + let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; + + v.data().get(ctx, generator, &index_addr, None).into() + } else { + match &slice.node { + ExprKind::Tuple { elts, .. } => { + let slices = elts + .iter() + .enumerate() + .map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64)) + .take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some)) + .collect::, _>>()?; + if slices.len() < elts.len() { + return Ok(None); + } + + let slices = slices.into_iter().map(Option::unwrap).collect_vec(); + + numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into() + } + + ExprKind::Slice { .. } => { + let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { + return Ok(None); + }; + + numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into() + } + + _ => { + // Accessing an element from a multi-dimensional `ndarray` + + let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; + + // Create a new array, remove the top dimension from the dimension-size-list, and copy the + // elements over + let subscripted_ndarray = + generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; + let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None); + + let num_dims = v.load_ndims(ctx); + ndarray.store_ndims( ctx, generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - call_memcpy_generic( - ctx, - ndarray.dim_sizes().base_ptr(ctx, generator), - v_dims_src_ptr, - ctx.builder - .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); + ctx.builder + .build_int_sub(num_dims, llvm_usize.const_int(1, false), "") + .unwrap(), + ); - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); + let ndarray_num_dims = ndarray.load_ndims(ctx); + ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); - let v_data_src_ptr = v.data().ptr_offset( - ctx, - generator, - &ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None), - None, - ); - call_memcpy_generic( - ctx, - ndarray.data().base_ptr(ctx, generator), - v_data_src_ptr, - ctx.builder - .build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "") - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); + let ndarray_num_dims = ndarray.load_ndims(ctx); + let v_dims_src_ptr = unsafe { + v.dim_sizes().ptr_offset_unchecked( + ctx, + generator, + &llvm_usize.const_int(1, false), + None, + ) + }; + call_memcpy_generic( + ctx, + ndarray.dim_sizes().base_ptr(ctx, generator), + v_dims_src_ptr, + ctx.builder + .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") + .map(Into::into) + .unwrap(), + llvm_i1.const_zero(), + ); - ndarray.as_base_value().into() + let ndarray_num_elems = call_ndarray_calc_size( + generator, + ctx, + &ndarray.dim_sizes().as_slice_value(ctx, generator), + (None, None), + ); + ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); + + let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None); + call_memcpy_generic( + ctx, + ndarray.data().base_ptr(ctx, generator), + v_data_src_ptr, + ctx.builder + .build_int_mul( + ndarray_num_elems, + llvm_ndarray_data_t.size_of().unwrap(), + "", + ) + .map(Into::into) + .unwrap(), + llvm_i1.const_zero(), + ); + + ndarray.as_base_value().into() + } } })) }