forked from M-Labs/nac3
core/expr: Implement negative indices for ndarray
This commit is contained in:
parent
f0715e2b6d
commit
e0f440040c
|
@ -9,6 +9,7 @@ use crate::{
|
||||||
ListValue,
|
ListValue,
|
||||||
NDArrayValue,
|
NDArrayValue,
|
||||||
RangeValue,
|
RangeValue,
|
||||||
|
TypedArrayLikeAccessor,
|
||||||
UntypedArrayLikeAccessor,
|
UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||||
|
@ -18,7 +19,7 @@ use crate::{
|
||||||
irrt::*,
|
irrt::*,
|
||||||
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
|
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
|
||||||
numpy,
|
numpy,
|
||||||
stmt::{gen_raise, gen_var},
|
stmt::{gen_if_else_expr_callback, gen_raise, gen_var},
|
||||||
CodeGenContext, CodeGenTask,
|
CodeGenContext, CodeGenTask,
|
||||||
},
|
},
|
||||||
symbol_resolver::{SymbolValue, ValueEnum},
|
symbol_resolver::{SymbolValue, ValueEnum},
|
||||||
|
@ -1692,21 +1693,55 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
slice.location,
|
slice.location,
|
||||||
);
|
);
|
||||||
|
|
||||||
if ndims.len() == 1 && ndims[0] == 1 {
|
|
||||||
// Accessing an element from a 1-dimensional `ndarray`
|
|
||||||
|
|
||||||
if let ExprKind::Slice { .. } = &slice.node {
|
if let ExprKind::Slice { .. } = &slice.node {
|
||||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||||
}
|
}
|
||||||
|
|
||||||
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
|
||||||
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
let index = index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
|
||||||
|
|
||||||
|
gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::SGE,
|
||||||
|
index,
|
||||||
|
index.get_type().const_zero(),
|
||||||
|
"",
|
||||||
|
).unwrap())
|
||||||
|
},
|
||||||
|
|_, _| Ok(Some(index)),
|
||||||
|
|generator, ctx| {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
|
||||||
|
let len = unsafe {
|
||||||
|
v.dim_sizes().get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let index = ctx.builder.build_int_add(
|
||||||
|
len,
|
||||||
|
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(),
|
||||||
|
"",
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
|
||||||
|
},
|
||||||
|
)?.map(BasicValueEnum::into_int_value).unwrap()
|
||||||
} else {
|
} else {
|
||||||
return Ok(None)
|
return Ok(None)
|
||||||
};
|
};
|
||||||
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
|
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
|
||||||
ctx.builder.build_store(index_addr, index).unwrap();
|
ctx.builder.build_store(index_addr, index).unwrap();
|
||||||
|
|
||||||
|
if ndims.len() == 1 && ndims[0] == 1 {
|
||||||
|
// Accessing an element from a 1-dimensional `ndarray`
|
||||||
|
|
||||||
Ok(Some(v.data()
|
Ok(Some(v.data()
|
||||||
.get(
|
.get(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -1718,18 +1753,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
} else {
|
} else {
|
||||||
// Accessing an element from a multi-dimensional `ndarray`
|
// Accessing an element from a multi-dimensional `ndarray`
|
||||||
|
|
||||||
if let ExprKind::Slice { .. } = &slice.node {
|
|
||||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
|
||||||
}
|
|
||||||
|
|
||||||
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
|
||||||
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
|
||||||
} else {
|
|
||||||
return Ok(None)
|
|
||||||
};
|
|
||||||
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
|
|
||||||
ctx.builder.build_store(index_addr, index).unwrap();
|
|
||||||
|
|
||||||
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||||
// elements over
|
// elements over
|
||||||
let subscripted_ndarray = generator.gen_var_alloc(
|
let subscripted_ndarray = generator.gen_var_alloc(
|
||||||
|
|
|
@ -79,6 +79,13 @@ def test_ndarray_copy():
|
||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
output_ndarray_float_2(y)
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_neg_idx():
|
||||||
|
x = np_identity(2)
|
||||||
|
|
||||||
|
for i in range(-1, -3, -1):
|
||||||
|
for j in range(-1, -3, -1):
|
||||||
|
output_float64(x[i][j])
|
||||||
|
|
||||||
def test_ndarray_add():
|
def test_ndarray_add():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
y = x + np_ones([2, 2])
|
y = x + np_ones([2, 2])
|
||||||
|
@ -639,6 +646,7 @@ def run() -> int32:
|
||||||
test_ndarray_identity()
|
test_ndarray_identity()
|
||||||
test_ndarray_fill()
|
test_ndarray_fill()
|
||||||
test_ndarray_copy()
|
test_ndarray_copy()
|
||||||
|
test_ndarray_neg_idx()
|
||||||
test_ndarray_add()
|
test_ndarray_add()
|
||||||
test_ndarray_add_broadcast()
|
test_ndarray_add_broadcast()
|
||||||
test_ndarray_add_broadcast_lhs_scalar()
|
test_ndarray_add_broadcast_lhs_scalar()
|
||||||
|
|
Loading…
Reference in New Issue