core/ndstrides: refactor numpy indexing
This commit is contained in:
parent
86b0d31290
commit
2ab7b299b8
@ -2,6 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||
|
||||
use super::{
|
||||
irrt::slice::{RustUserSlice, SliceIndex},
|
||||
numpy_new::object::{NDArrayObject, ScalarOrNDArray},
|
||||
structure::ndarray::NpArray,
|
||||
};
|
||||
use crate::{
|
||||
@ -18,7 +19,6 @@ use crate::{
|
||||
call_memcpy_generic,
|
||||
},
|
||||
need_sret, numpy,
|
||||
numpy_new::util::alloca_ndarray,
|
||||
stmt::{
|
||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||
gen_var,
|
||||
@ -35,7 +35,7 @@ use crate::{
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
types::{AnyType, BasicType, BasicTypeEnum},
|
||||
values::{BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
|
||||
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
use itertools::{chain, izip, Either, Itertools};
|
||||
@ -44,7 +44,7 @@ use nac3parser::ast::{
|
||||
StrRef, Unaryop,
|
||||
};
|
||||
|
||||
use ndarray::indexing::{call_nac3_ndarray_index, RustNDIndex};
|
||||
use ndarray::indexing::RustNDIndex;
|
||||
|
||||
use super::{
|
||||
model::*,
|
||||
@ -2130,23 +2130,13 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates code for a subscript expression on an `ndarray`.
|
||||
///
|
||||
/// * `elem_ty` - The `Type` of the `NDArray` elements.
|
||||
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
||||
/// * `src_ndarray` - The `NDArray` value.
|
||||
/// * `subscript` - The subscript expression used to index into the `ndarray`.
|
||||
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
pub fn gen_ndarray_subscript_ndindexes<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
ndims: Type,
|
||||
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
|
||||
subscript: &Expr<Option<Type>>,
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
) -> Result<Vec<RustNDIndex<'ctx>>, String> {
|
||||
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let slice_index_model = IntModel(SliceIndex::default());
|
||||
|
||||
// Annoying notes about `slice`
|
||||
@ -2215,66 +2205,23 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
};
|
||||
rust_ndindexes.push(ndindex);
|
||||
}
|
||||
Ok(rust_ndindexes)
|
||||
}
|
||||
|
||||
// Extract the `ndims` from a `Type` to `i128`
|
||||
// We *HAVE* to know this statically, this is used to determine
|
||||
// whether this subscript expression returns a scalar or an ndarray
|
||||
let TypeEnum::TLiteral { values: ndims_values, .. } = &*ctx.unifier.get_ty_immutable(ndims)
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
assert_eq!(ndims_values.len(), 1);
|
||||
let src_ndims = i128::try_from(ndims_values[0].clone()).unwrap();
|
||||
|
||||
// Check for "too many indices for array: array is ..." error
|
||||
if src_ndims < rust_ndindexes.len() as i128 {
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.ctx.bool_type().const_int(1, false),
|
||||
"0:IndexError",
|
||||
"too many indices for array: array is {0}-dimensional, but {1} were indexed",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let dst_ndims = RustNDIndex::deduce_ndims_after_slicing(&rust_ndindexes, src_ndims as i32);
|
||||
let dst_ndarray = alloca_ndarray(
|
||||
generator,
|
||||
ctx,
|
||||
sizet_model.constant(tyctx, ctx.ctx, dst_ndims as u64),
|
||||
"subndarray",
|
||||
);
|
||||
|
||||
// Prepare the subscripts
|
||||
let (num_ndindexes, ndindexes) = RustNDIndex::alloca_ndindexes(tyctx, ctx, &rust_ndindexes);
|
||||
|
||||
// NOTE: IRRT does check for indexing errors
|
||||
call_nac3_ndarray_index(generator, ctx, num_ndindexes, ndindexes, src_ndarray, dst_ndarray);
|
||||
|
||||
// ...and return the result, with two cases
|
||||
let result_llvm_value: BasicValueEnum<'_> = if dst_ndims == 0 {
|
||||
// 1) ndims == 0 (this happens when you do `np.zerps((3, 4))[1, 1]`), return the element
|
||||
|
||||
let pelement = dst_ndarray.gep(ctx, |f| f.data).load(tyctx, ctx, "pelement"); // `*data` points to the first element by definition
|
||||
|
||||
// Cast the opaque `pelement` ptr to `elem_ty`
|
||||
let elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let pelement = ctx
|
||||
.builder
|
||||
.build_pointer_cast(
|
||||
pelement.value,
|
||||
elem_ty.ptr_type(AddressSpace::default()),
|
||||
"pelement_casted",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
ctx.builder.build_load(pelement, "element").unwrap().as_basic_value_enum()
|
||||
} else {
|
||||
// 2) ndims > 0 (other cases), return subndarray
|
||||
dst_ndarray.value.as_basic_value_enum()
|
||||
};
|
||||
Ok(Some(ValueEnum::Dynamic(result_llvm_value)))
|
||||
/// Generates code for a subscript expression on an `ndarray`.
|
||||
///
|
||||
/// * `elem_ty` - The `Type` of the `NDArray` elements.
|
||||
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
||||
/// * `src_ndarray` - The `NDArray` value.
|
||||
/// * `subscript` - The subscript expression used to index into the `ndarray`.
|
||||
pub fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayObject<'ctx>,
|
||||
subscript: &Expr<Option<Type>>,
|
||||
) -> Result<ScalarOrNDArray<'ctx>, String> {
|
||||
let indexes = gen_ndarray_subscript_ndindexes(generator, ctx, subscript)?;
|
||||
Ok(ndarray.index(generator, ctx, &indexes, "subndarray"))
|
||||
}
|
||||
|
||||
/// See [`CodeGenerator::gen_expr`].
|
||||
@ -2920,7 +2867,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
let pndarray_model = PtrModel(StructModel(NpArray));
|
||||
|
||||
let (dtype, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
||||
let (&dtype, &ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
||||
|
||||
let Some(ndarray) = generator.gen_expr(ctx, value)? else {
|
||||
return Ok(None);
|
||||
@ -2929,10 +2876,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
let ndarray =
|
||||
ndarray.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
|
||||
let ndarray = pndarray_model.check_value(tyctx, ctx.ctx, ndarray).unwrap();
|
||||
let ndarray = NDArrayObject { dtype, ndims, instance: ndarray };
|
||||
|
||||
return gen_ndarray_subscript_expr(
|
||||
generator, ctx, *dtype, *ndims, ndarray, slice,
|
||||
);
|
||||
let result = gen_ndarray_subscript_expr(generator, ctx, ndarray, slice)?;
|
||||
return Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())));
|
||||
}
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let index: u32 =
|
||||
|
@ -76,7 +76,7 @@ impl<'ctx> RustNDIndex<'ctx> {
|
||||
dst_ndindex_ptr.gep(ctx, |f| f.data).store(ctx, data);
|
||||
}
|
||||
|
||||
/// Allocate an array of `NDIndex`es onto the stack and return its stack pointer.
|
||||
/// Allocate an array of `NDIndex`es on the stack and return its stack pointer.
|
||||
pub fn alloca_ndindexes(
|
||||
tyctx: TypeContext<'ctx>,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
@ -97,10 +97,10 @@ impl<'ctx> RustNDIndex<'ctx> {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn deduce_ndims_after_slicing(slices: &[RustNDIndex], original_ndims: i32) -> i32 {
|
||||
let mut final_ndims: i32 = original_ndims;
|
||||
for slice in slices {
|
||||
match slice {
|
||||
pub fn deduce_ndims_after_indexing(indices: &[RustNDIndex], original_ndims: u64) -> u64 {
|
||||
let mut final_ndims = original_ndims;
|
||||
for index in indices {
|
||||
match index {
|
||||
RustNDIndex::SingleElement(_) => {
|
||||
final_ndims -= 1;
|
||||
}
|
||||
|
76
nac3core/src/codegen/numpy_new/indexing.rs
Normal file
76
nac3core/src/codegen/numpy_new/indexing.rs
Normal file
@ -0,0 +1,76 @@
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt::ndarray::indexing::{call_nac3_ndarray_index, RustNDIndex},
|
||||
model::*,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::{Type, Unifier},
|
||||
};
|
||||
|
||||
use super::{
|
||||
object::{NDArrayObject, ScalarObject, ScalarOrNDArray},
|
||||
util::{create_ndims, extract_ndims},
|
||||
};
|
||||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
pub fn deduce_ndims_after_indexing_with(
|
||||
&self,
|
||||
unifier: &mut Unifier,
|
||||
indexes: &[RustNDIndex<'ctx>],
|
||||
) -> Type {
|
||||
let ndims = extract_ndims(unifier, self.ndims);
|
||||
let new_ndims = RustNDIndex::deduce_ndims_after_indexing(indexes, ndims);
|
||||
create_ndims(unifier, new_ndims)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn index_always_ndarray<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
indexes: &[RustNDIndex<'ctx>],
|
||||
name: &str,
|
||||
) -> Self {
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
|
||||
let dst_ndims = self.deduce_ndims_after_indexing_with(&mut ctx.unifier, indexes);
|
||||
let dst_ndarray = NDArrayObject::alloca(generator, ctx, dst_ndims, self.dtype, name);
|
||||
|
||||
let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(tyctx, ctx, indexes);
|
||||
call_nac3_ndarray_index(
|
||||
generator,
|
||||
ctx,
|
||||
num_indexes,
|
||||
indexes,
|
||||
self.instance,
|
||||
dst_ndarray.instance,
|
||||
);
|
||||
|
||||
dst_ndarray
|
||||
}
|
||||
|
||||
pub fn index<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
indexes: &[RustNDIndex<'ctx>],
|
||||
name: &str,
|
||||
) -> ScalarOrNDArray<'ctx> {
|
||||
let tyctx = generator.type_context(ctx.ctx);
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
let subndarray = self.index_always_ndarray(generator, ctx, indexes, name);
|
||||
if subndarray.is_unsized(&ctx.unifier) {
|
||||
// TODO: This actually never fails, don't use the `checked_` version.
|
||||
let value = subndarray.checked_get_nth_element(
|
||||
generator,
|
||||
ctx,
|
||||
sizet_model.const_0(tyctx, ctx.ctx),
|
||||
name,
|
||||
);
|
||||
ScalarOrNDArray::Scalar(ScalarObject { dtype: self.dtype, value })
|
||||
} else {
|
||||
ScalarOrNDArray::NDArray(subndarray)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
pub mod broadcast;
|
||||
pub mod factory;
|
||||
pub mod indexing;
|
||||
pub mod object;
|
||||
pub mod util;
|
||||
pub mod view;
|
||||
|
@ -29,6 +29,7 @@ pub enum ScalarOrNDArray<'ctx> {
|
||||
|
||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
||||
#[must_use]
|
||||
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
||||
match self {
|
||||
ScalarOrNDArray::Scalar(scalar) => scalar.value,
|
||||
|
@ -1,4 +1,8 @@
|
||||
use inkwell::types::BasicType;
|
||||
use inkwell::{
|
||||
types::BasicType,
|
||||
values::{BasicValueEnum, PointerValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use util::gen_model_memcpy;
|
||||
|
||||
use crate::{
|
||||
@ -249,6 +253,37 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
|
||||
}
|
||||
|
||||
pub fn checked_get_nth_pelement<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
i: Int<'ctx, SizeT>,
|
||||
name: &str,
|
||||
) -> PointerValue<'ctx> {
|
||||
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
|
||||
|
||||
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, i);
|
||||
ctx.builder
|
||||
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn checked_get_nth_element<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
i: Int<'ctx, SizeT>,
|
||||
name: &str,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let pelement = self.checked_get_nth_pelement(generator, ctx, i, "pelement");
|
||||
ctx.builder.build_load(pelement, name).unwrap()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_unsized(&self, unifier: &Unifier) -> bool {
|
||||
extract_ndims(unifier, self.ndims) == 0
|
||||
}
|
||||
|
||||
pub fn size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
|
Loading…
Reference in New Issue
Block a user