core/expr: Implement negative indices for ndarray

This commit is contained in:
David Mak 2024-04-15 12:20:13 +08:00
parent f0715e2b6d
commit e0f440040c
2 changed files with 56 additions and 25 deletions

View File

@ -9,6 +9,7 @@ use crate::{
ListValue, ListValue,
NDArrayValue, NDArrayValue,
RangeValue, RangeValue,
TypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeAccessor,
}, },
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
@ -18,7 +19,7 @@ use crate::{
irrt::*, irrt::*,
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi}, llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
numpy, numpy,
stmt::{gen_raise, gen_var}, stmt::{gen_if_else_expr_callback, gen_raise, gen_var},
CodeGenContext, CodeGenTask, CodeGenContext, CodeGenTask,
}, },
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
@ -1692,21 +1693,55 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
slice.location, slice.location,
); );
if ndims.len() == 1 && ndims[0] == 1 {
// Accessing an element from a 1-dimensional `ndarray`
if let ExprKind::Slice { .. } = &slice.node { if let ExprKind::Slice { .. } = &slice.node {
return Err(String::from("subscript operator for ndarray not implemented")) return Err(String::from("subscript operator for ndarray not implemented"))
} }
let index = if let Some(v) = generator.gen_expr(ctx, slice)? { let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() let index = index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx.builder.build_int_compare(
IntPredicate::SGE,
index,
index.get_type().const_zero(),
"",
).unwrap())
},
|_, _| Ok(Some(index)),
|generator, ctx| {
let llvm_i32 = ctx.ctx.i32_type();
let len = unsafe {
v.dim_sizes().get_typed_unchecked(
ctx,
generator,
llvm_usize.const_zero(),
None,
)
};
let index = ctx.builder.build_int_add(
len,
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(),
"",
).unwrap();
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
},
)?.map(BasicValueEnum::into_int_value).unwrap()
} else { } else {
return Ok(None) return Ok(None)
}; };
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?; let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
ctx.builder.build_store(index_addr, index).unwrap(); ctx.builder.build_store(index_addr, index).unwrap();
if ndims.len() == 1 && ndims[0] == 1 {
// Accessing an element from a 1-dimensional `ndarray`
Ok(Some(v.data() Ok(Some(v.data()
.get( .get(
ctx, ctx,
@ -1718,18 +1753,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
} else { } else {
// Accessing an element from a multi-dimensional `ndarray` // Accessing an element from a multi-dimensional `ndarray`
if let ExprKind::Slice { .. } = &slice.node {
return Err(String::from("subscript operator for ndarray not implemented"))
}
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
} else {
return Ok(None)
};
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
ctx.builder.build_store(index_addr, index).unwrap();
// Create a new array, remove the top dimension from the dimension-size-list, and copy the // Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over // elements over
let subscripted_ndarray = generator.gen_var_alloc( let subscripted_ndarray = generator.gen_var_alloc(

View File

@ -79,6 +79,13 @@ def test_ndarray_copy():
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_ndarray_float_2(y) output_ndarray_float_2(y)
def test_ndarray_neg_idx():
x = np_identity(2)
for i in range(-1, -3, -1):
for j in range(-1, -3, -1):
output_float64(x[i][j])
def test_ndarray_add(): def test_ndarray_add():
x = np_identity(2) x = np_identity(2)
y = x + np_ones([2, 2]) y = x + np_ones([2, 2])
@ -639,6 +646,7 @@ def run() -> int32:
test_ndarray_identity() test_ndarray_identity()
test_ndarray_fill() test_ndarray_fill()
test_ndarray_copy() test_ndarray_copy()
test_ndarray_neg_idx()
test_ndarray_add() test_ndarray_add()
test_ndarray_add_broadcast() test_ndarray_add_broadcast()
test_ndarray_add_broadcast_lhs_scalar() test_ndarray_add_broadcast_lhs_scalar()