From ed79d5bb9e155691d4e2eb1403b7c6934b02fb4b Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 30 May 2024 16:08:15 +0800 Subject: [PATCH] core/expr: Add support for multi-dim slicing of NDArrays --- nac3core/src/codegen/expr.rs | 297 +++++++++++------- nac3core/src/typecheck/type_inferencer/mod.rs | 30 +- 2 files changed, 216 insertions(+), 111 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 674d7840..f24c7d78 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1667,6 +1667,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( slice: &Expr>, ) -> Result>, String> { let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { @@ -1712,32 +1713,11 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( slice.location, ); - if let ExprKind::Slice { lower, upper, step } = &slice.node { - let dim0_sz = unsafe { - v.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - let Some((start, stop, step)) = handle_slice_indices( - lower, - upper, - step, - ctx, - generator, - dim0_sz, - )? else { return Ok(None) }; - - return Ok(Some(numpy::ndarray_sliced_copy( - generator, - ctx, - ty, - v, - &[(start, stop, step)], - )?.as_ptr_value().into())) - } - - let index = if let Some(index) = generator.gen_expr(ctx, slice)? { - let index = index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value(); - + // Normalizes a possibly-negative index to its corresponding positive index + let normalize_index = |generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + index: IntValue<'ctx>, + dim: u64| { gen_if_else_expr_callback( generator, ctx, @@ -1757,7 +1737,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( v.dim_sizes().get_typed_unchecked( ctx, generator, - &llvm_usize.const_zero(), + &llvm_usize.const_int(dim, true), None, ) }; @@ -1770,97 +1750,194 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) }, - )?.map(BasicValueEnum::into_int_value).unwrap() - } else { - return Ok(None) + ).map(|v| v.map(BasicValueEnum::into_int_value)) }; - let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?; - ctx.builder.build_store(index_addr, index).unwrap(); - if ndims.len() == 1 && ndims[0] == 1 { - // Accessing an element from a 1-dimensional `ndarray` + // Converts a slice expression into a slice-range tuple + let expr_to_slice = |generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + node: &ExprKind>, + dim: u64| { + match node { + ExprKind::Constant { value: Constant::Int(v), .. } => { + let Some(index) = normalize_index( + generator, ctx, llvm_i32.const_int(*v as u64, true), dim, + )? else { + return Ok(None) + }; - Ok(Some(v.data() - .get( + Ok(Some((index, index, llvm_i32.const_int(1, true)))) + } + + ExprKind::Slice { lower, upper, step } => { + let dim_sz = unsafe { + v.dim_sizes() + .get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(dim, false), + None, + ) + }; + + handle_slice_indices(lower, upper, step, ctx, generator, dim_sz) + } + + _ => { + let Some(index) = generator.gen_expr(ctx, slice)? else { + return Ok(None) + }; + let index = index + .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? + .into_int_value(); + let Some(index) = normalize_index(generator, ctx, index, dim)? else { + return Ok(None) + }; + + Ok(Some((index, index, llvm_i32.const_int(1, true)))) + } + } + }; + + Ok(Some(match &slice.node { + ExprKind::Tuple { elts, .. } => { + let slices = elts.iter().enumerate() + .map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64)) + .take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some)) + .collect::, _>>()?; + if slices.len() < elts.len() { + return Ok(None) + } + + let slices = slices.into_iter() + .map(Option::unwrap) + .collect_vec(); + + numpy::ndarray_sliced_copy( + generator, + ctx, + ty, + v, + &slices, + )?.as_ptr_value().into() + } + + ExprKind::Slice { .. } => { + let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { + return Ok(None) + }; + + numpy::ndarray_sliced_copy( + generator, + ctx, + ty, + v, + &[slice], + )?.as_ptr_value().into() + } + + _ => { + let index = if let Some(index) = generator.gen_expr(ctx, slice)? { + index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() + } else { + return Ok(None) + }; + let Some(index) = normalize_index(generator, ctx, index, 0)? else { + return Ok(None) + }; + let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?; + ctx.builder.build_store(index_addr, index).unwrap(); + + if ndims.len() == 1 && ndims[0] == 1 { + // Accessing an element from a 1-dimensional `ndarray` + + return Ok(Some(v.data() + .get( + ctx, + generator, + &ArraySliceValue::from_ptr_val( + index_addr, + llvm_usize.const_int(1, false), + None, + ), + None, + ) + .into())) + } + + // Accessing an element from a multi-dimensional `ndarray` + + // Create a new array, remove the top dimension from the dimension-size-list, and copy the + // elements over + let subscripted_ndarray = generator.gen_var_alloc( + ctx, + llvm_ndarray_t.into(), + None + )?; + let ndarray = NDArrayValue::from_ptr_val( + subscripted_ndarray, + llvm_usize, + None + ); + + let num_dims = v.load_ndims(ctx); + ndarray.store_ndims( + ctx, + generator, + ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), "").unwrap(), + ); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + let v_dims_src_ptr = unsafe { + v.dim_sizes().ptr_offset_unchecked( + ctx, + generator, + &llvm_usize.const_int(1, false), + None, + ) + }; + call_memcpy_generic( + ctx, + ndarray.dim_sizes().base_ptr(ctx, generator), + v_dims_src_ptr, + ctx.builder + .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") + .map(Into::into) + .unwrap(), + llvm_i1.const_zero(), + ); + + let ndarray_num_elems = call_ndarray_calc_size( + generator, + ctx, + &ndarray.dim_sizes().as_slice_value(ctx, generator), + (None, None), + ); + ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); + + let v_data_src_ptr = v.data().ptr_offset( ctx, generator, &ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None), - None, - ) - .into())) - } else { - // Accessing an element from a multi-dimensional `ndarray` - - // Create a new array, remove the top dimension from the dimension-size-list, and copy the - // elements over - let subscripted_ndarray = generator.gen_var_alloc( - ctx, - llvm_ndarray_t.into(), - None - )?; - let ndarray = NDArrayValue::from_ptr_val( - subscripted_ndarray, - llvm_usize, - None - ); - - let num_dims = v.load_ndims(ctx); - ndarray.store_ndims( - ctx, - generator, - ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), "").unwrap(), - ); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let v_dims_src_ptr = unsafe { - v.dim_sizes().ptr_offset_unchecked( + None + ); + call_memcpy_generic( ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - call_memcpy_generic( - ctx, - ndarray.dim_sizes().base_ptr(ctx, generator), - v_dims_src_ptr, - ctx.builder - .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); + ndarray.data().base_ptr(ctx, generator), + v_data_src_ptr, + ctx.builder + .build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "") + .map(Into::into) + .unwrap(), + llvm_i1.const_zero(), + ); - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); - - let v_data_src_ptr = v.data().ptr_offset( - ctx, - generator, - &ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None), - None - ); - call_memcpy_generic( - ctx, - ndarray.data().base_ptr(ctx, generator), - v_data_src_ptr, - ctx.builder - .build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "") - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - Ok(Some(ndarray.as_ptr_value().into())) - } + ndarray.as_ptr_value().into() + } + })) } /// See [`CodeGenerator::gen_expr`]. diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index c26366d8..88b27f01 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -2,6 +2,7 @@ use std::collections::{HashMap, HashSet}; use std::convert::{From, TryInto}; use std::iter::once; use std::{cell::RefCell, sync::Arc}; +use std::ops::Not; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; @@ -554,7 +555,10 @@ impl<'a> Fold<()> for Inferencer<'a> { ExprKind::ListComp { .. } | ExprKind::Lambda { .. } | ExprKind::Call { .. } => expr.custom, // already computed - ExprKind::Slice { .. } => None, // we don't need it for slice + ExprKind::Slice { .. } => { + // slices aren't exactly ranges, but for our purposes this should suffice + Some(self.primitives.range) + } _ => return report_error("not supported", expr.location), }; Ok(ast::Expr { custom, location: expr.location, node: expr.node }) @@ -1642,6 +1646,30 @@ impl<'a> Inferencer<'a> { } } } + ExprKind::Tuple { elts, .. } => { + if value.custom + .unwrap() + .obj_id(self.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) + .not() { + return report_error("Tuple slices are only supported for ndarrays", slice.location) + } + + for elt in elts { + if let ExprKind::Slice { lower, upper, step } = &elt.node { + for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { + self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?; + } + } else { + self.constrain(elt.custom.unwrap(), self.primitives.int32, &elt.location)?; + } + } + + let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + let ndarray_ty = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); + self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?; + Ok(ndarray_ty) + } _ => { if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)