forked from M-Labs/nac3
core/expr: Implement negative indices for ndarray
This commit is contained in:
parent
f0715e2b6d
commit
e0f440040c
|
@ -9,6 +9,7 @@ use crate::{
|
|||
ListValue,
|
||||
NDArrayValue,
|
||||
RangeValue,
|
||||
TypedArrayLikeAccessor,
|
||||
UntypedArrayLikeAccessor,
|
||||
},
|
||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||
|
@ -18,7 +19,7 @@ use crate::{
|
|||
irrt::*,
|
||||
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
|
||||
numpy,
|
||||
stmt::{gen_raise, gen_var},
|
||||
stmt::{gen_if_else_expr_callback, gen_raise, gen_var},
|
||||
CodeGenContext, CodeGenTask,
|
||||
},
|
||||
symbol_resolver::{SymbolValue, ValueEnum},
|
||||
|
@ -1692,21 +1693,55 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|||
slice.location,
|
||||
);
|
||||
|
||||
if let ExprKind::Slice { .. } = &slice.node {
|
||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||
}
|
||||
|
||||
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
|
||||
let index = index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
|
||||
|
||||
gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx.builder.build_int_compare(
|
||||
IntPredicate::SGE,
|
||||
index,
|
||||
index.get_type().const_zero(),
|
||||
"",
|
||||
).unwrap())
|
||||
},
|
||||
|_, _| Ok(Some(index)),
|
||||
|generator, ctx| {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
let len = unsafe {
|
||||
v.dim_sizes().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
llvm_usize.const_zero(),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
let index = ctx.builder.build_int_add(
|
||||
len,
|
||||
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(),
|
||||
"",
|
||||
).unwrap();
|
||||
|
||||
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
|
||||
},
|
||||
)?.map(BasicValueEnum::into_int_value).unwrap()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
|
||||
ctx.builder.build_store(index_addr, index).unwrap();
|
||||
|
||||
if ndims.len() == 1 && ndims[0] == 1 {
|
||||
// Accessing an element from a 1-dimensional `ndarray`
|
||||
|
||||
if let ExprKind::Slice { .. } = &slice.node {
|
||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||
}
|
||||
|
||||
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
||||
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
|
||||
ctx.builder.build_store(index_addr, index).unwrap();
|
||||
|
||||
Ok(Some(v.data()
|
||||
.get(
|
||||
ctx,
|
||||
|
@ -1718,18 +1753,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|||
} else {
|
||||
// Accessing an element from a multi-dimensional `ndarray`
|
||||
|
||||
if let ExprKind::Slice { .. } = &slice.node {
|
||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||
}
|
||||
|
||||
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
||||
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
|
||||
ctx.builder.build_store(index_addr, index).unwrap();
|
||||
|
||||
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||
// elements over
|
||||
let subscripted_ndarray = generator.gen_var_alloc(
|
||||
|
|
|
@ -79,6 +79,13 @@ def test_ndarray_copy():
|
|||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_neg_idx():
|
||||
x = np_identity(2)
|
||||
|
||||
for i in range(-1, -3, -1):
|
||||
for j in range(-1, -3, -1):
|
||||
output_float64(x[i][j])
|
||||
|
||||
def test_ndarray_add():
|
||||
x = np_identity(2)
|
||||
y = x + np_ones([2, 2])
|
||||
|
@ -639,6 +646,7 @@ def run() -> int32:
|
|||
test_ndarray_identity()
|
||||
test_ndarray_fill()
|
||||
test_ndarray_copy()
|
||||
test_ndarray_neg_idx()
|
||||
test_ndarray_add()
|
||||
test_ndarray_add_broadcast()
|
||||
test_ndarray_add_broadcast_lhs_scalar()
|
||||
|
|
Loading…
Reference in New Issue