From e0f440040c5f15cc791299b8155e953c55060d36 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 15 Apr 2024 12:20:13 +0800 Subject: [PATCH] core/expr: Implement negative indices for ndarray --- nac3core/src/codegen/expr.rs | 73 ++++++++++++++++++++---------- nac3standalone/demo/src/ndarray.py | 8 ++++ 2 files changed, 56 insertions(+), 25 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 474052a72..b44e1a556 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -9,6 +9,7 @@ use crate::{ ListValue, NDArrayValue, RangeValue, + TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, @@ -18,7 +19,7 @@ use crate::{ irrt::*, llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi}, numpy, - stmt::{gen_raise, gen_var}, + stmt::{gen_if_else_expr_callback, gen_raise, gen_var}, CodeGenContext, CodeGenTask, }, symbol_resolver::{SymbolValue, ValueEnum}, @@ -1692,21 +1693,55 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( slice.location, ); + if let ExprKind::Slice { .. } = &slice.node { + return Err(String::from("subscript operator for ndarray not implemented")) + } + + let index = if let Some(index) = generator.gen_expr(ctx, slice)? { + let index = index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value(); + + gen_if_else_expr_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx.builder.build_int_compare( + IntPredicate::SGE, + index, + index.get_type().const_zero(), + "", + ).unwrap()) + }, + |_, _| Ok(Some(index)), + |generator, ctx| { + let llvm_i32 = ctx.ctx.i32_type(); + + let len = unsafe { + v.dim_sizes().get_typed_unchecked( + ctx, + generator, + llvm_usize.const_zero(), + None, + ) + }; + + let index = ctx.builder.build_int_add( + len, + ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(), + "", + ).unwrap(); + + Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) + }, + )?.map(BasicValueEnum::into_int_value).unwrap() + } else { + return Ok(None) + }; + let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?; + ctx.builder.build_store(index_addr, index).unwrap(); + if ndims.len() == 1 && ndims[0] == 1 { // Accessing an element from a 1-dimensional `ndarray` - if let ExprKind::Slice { .. } = &slice.node { - return Err(String::from("subscript operator for ndarray not implemented")) - } - - let index = if let Some(v) = generator.gen_expr(ctx, slice)? { - v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() - } else { - return Ok(None) - }; - let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?; - ctx.builder.build_store(index_addr, index).unwrap(); - Ok(Some(v.data() .get( ctx, @@ -1718,18 +1753,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( } else { // Accessing an element from a multi-dimensional `ndarray` - if let ExprKind::Slice { .. } = &slice.node { - return Err(String::from("subscript operator for ndarray not implemented")) - } - - let index = if let Some(v) = generator.gen_expr(ctx, slice)? { - v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() - } else { - return Ok(None) - }; - let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?; - ctx.builder.build_store(index_addr, index).unwrap(); - // Create a new array, remove the top dimension from the dimension-size-list, and copy the // elements over let subscripted_ndarray = generator.gen_var_alloc( diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 49fdd2093..371bdb480 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -79,6 +79,13 @@ def test_ndarray_copy(): output_ndarray_float_2(x) output_ndarray_float_2(y) +def test_ndarray_neg_idx(): + x = np_identity(2) + + for i in range(-1, -3, -1): + for j in range(-1, -3, -1): + output_float64(x[i][j]) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -639,6 +646,7 @@ def run() -> int32: test_ndarray_identity() test_ndarray_fill() test_ndarray_copy() + test_ndarray_neg_idx() test_ndarray_add() test_ndarray_add_broadcast() test_ndarray_add_broadcast_lhs_scalar()