forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: refactor numpy indexing

This commit is contained in:
lyken 2024-07-30 17:52:28 +08:00
parent 86b0d31290
commit 2ab7b299b8
6 changed files with 144 additions and 84 deletions

View File

@ -2,6 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use super::{
irrt::slice::{RustUserSlice, SliceIndex},
numpy_new::object::{NDArrayObject, ScalarOrNDArray},
structure::ndarray::NpArray,
};
use crate::{
@ -18,7 +19,6 @@ use crate::{
call_memcpy_generic,
},
need_sret, numpy,
numpy_new::util::alloca_ndarray,
stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var,
@ -35,7 +35,7 @@ use crate::{
use inkwell::{
attributes::{Attribute, AttributeLoc},
types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel,
};
use itertools::{chain, izip, Either, Itertools};
@ -44,7 +44,7 @@ use nac3parser::ast::{
StrRef, Unaryop,
};
use ndarray::indexing::{call_nac3_ndarray_index, RustNDIndex};
use ndarray::indexing::RustNDIndex;
use super::{
model::*,
@ -2130,23 +2130,13 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
)
}
/// Generates code for a subscript expression on an `ndarray`.
///
/// * `elem_ty` - The `Type` of the `NDArray` elements.
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `src_ndarray` - The `NDArray` value.
/// * `subscript` - The subscript expression used to index into the `ndarray`.
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
pub fn gen_ndarray_subscript_ndindexes<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
ndims: Type,
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
subscript: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> {
) -> Result<Vec<RustNDIndex<'ctx>>, String> {
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let slice_index_model = IntModel(SliceIndex::default());
// Annoying notes about `slice`
@ -2215,66 +2205,23 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
};
rust_ndindexes.push(ndindex);
}
Ok(rust_ndindexes)
}
// Extract the `ndims` from a `Type` to `i128`
// We *HAVE* to know this statically, this is used to determine
// whether this subscript expression returns a scalar or an ndarray
let TypeEnum::TLiteral { values: ndims_values, .. } = &*ctx.unifier.get_ty_immutable(ndims)
else {
unreachable!()
};
assert_eq!(ndims_values.len(), 1);
let src_ndims = i128::try_from(ndims_values[0].clone()).unwrap();
// Check for "too many indices for array: array is ..." error
if src_ndims < rust_ndindexes.len() as i128 {
ctx.make_assert(
generator,
ctx.ctx.bool_type().const_int(1, false),
"0:IndexError",
"too many indices for array: array is {0}-dimensional, but {1} were indexed",
[None, None, None],
ctx.current_loc,
);
}
let dst_ndims = RustNDIndex::deduce_ndims_after_slicing(&rust_ndindexes, src_ndims as i32);
let dst_ndarray = alloca_ndarray(
generator,
ctx,
sizet_model.constant(tyctx, ctx.ctx, dst_ndims as u64),
"subndarray",
);
// Prepare the subscripts
let (num_ndindexes, ndindexes) = RustNDIndex::alloca_ndindexes(tyctx, ctx, &rust_ndindexes);
// NOTE: IRRT does check for indexing errors
call_nac3_ndarray_index(generator, ctx, num_ndindexes, ndindexes, src_ndarray, dst_ndarray);
// ...and return the result, with two cases
let result_llvm_value: BasicValueEnum<'_> = if dst_ndims == 0 {
// 1) ndims == 0 (this happens when you do `np.zerps((3, 4))[1, 1]`), return the element
let pelement = dst_ndarray.gep(ctx, |f| f.data).load(tyctx, ctx, "pelement"); // `*data` points to the first element by definition
// Cast the opaque `pelement` ptr to `elem_ty`
let elem_ty = ctx.get_llvm_type(generator, elem_ty);
let pelement = ctx
.builder
.build_pointer_cast(
pelement.value,
elem_ty.ptr_type(AddressSpace::default()),
"pelement_casted",
)
.unwrap();
ctx.builder.build_load(pelement, "element").unwrap().as_basic_value_enum()
} else {
// 2) ndims > 0 (other cases), return subndarray
dst_ndarray.value.as_basic_value_enum()
};
Ok(Some(ValueEnum::Dynamic(result_llvm_value)))
/// Generates code for a subscript expression on an `ndarray`.
///
/// * `elem_ty` - The `Type` of the `NDArray` elements.
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `src_ndarray` - The `NDArray` value.
/// * `subscript` - The subscript expression used to index into the `ndarray`.
pub fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
subscript: &Expr<Option<Type>>,
) -> Result<ScalarOrNDArray<'ctx>, String> {
let indexes = gen_ndarray_subscript_ndindexes(generator, ctx, subscript)?;
Ok(ndarray.index(generator, ctx, &indexes, "subndarray"))
}
/// See [`CodeGenerator::gen_expr`].
@ -2920,7 +2867,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let tyctx = generator.type_context(ctx.ctx);
let pndarray_model = PtrModel(StructModel(NpArray));
let (dtype, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
let (&dtype, &ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
let Some(ndarray) = generator.gen_expr(ctx, value)? else {
return Ok(None);
@ -2929,10 +2876,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let ndarray =
ndarray.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
let ndarray = pndarray_model.check_value(tyctx, ctx.ctx, ndarray).unwrap();
let ndarray = NDArrayObject { dtype, ndims, instance: ndarray };
return gen_ndarray_subscript_expr(
generator, ctx, *dtype, *ndims, ndarray, slice,
);
let result = gen_ndarray_subscript_expr(generator, ctx, ndarray, slice)?;
return Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())));
}
TypeEnum::TTuple { .. } => {
let index: u32 =

View File

@ -76,7 +76,7 @@ impl<'ctx> RustNDIndex<'ctx> {
dst_ndindex_ptr.gep(ctx, |f| f.data).store(ctx, data);
}
/// Allocate an array of `NDIndex`es onto the stack and return its stack pointer.
/// Allocate an array of `NDIndex`es on the stack and return its stack pointer.
pub fn alloca_ndindexes(
tyctx: TypeContext<'ctx>,
ctx: &CodeGenContext<'ctx, '_>,
@ -97,10 +97,10 @@ impl<'ctx> RustNDIndex<'ctx> {
}
#[must_use]
pub fn deduce_ndims_after_slicing(slices: &[RustNDIndex], original_ndims: i32) -> i32 {
let mut final_ndims: i32 = original_ndims;
for slice in slices {
match slice {
pub fn deduce_ndims_after_indexing(indices: &[RustNDIndex], original_ndims: u64) -> u64 {
let mut final_ndims = original_ndims;
for index in indices {
match index {
RustNDIndex::SingleElement(_) => {
final_ndims -= 1;
}

View File

@ -0,0 +1,76 @@
use crate::{
codegen::{
irrt::ndarray::indexing::{call_nac3_ndarray_index, RustNDIndex},
model::*,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{Type, Unifier},
};
use super::{
object::{NDArrayObject, ScalarObject, ScalarOrNDArray},
util::{create_ndims, extract_ndims},
};
impl<'ctx> NDArrayObject<'ctx> {
pub fn deduce_ndims_after_indexing_with(
&self,
unifier: &mut Unifier,
indexes: &[RustNDIndex<'ctx>],
) -> Type {
let ndims = extract_ndims(unifier, self.ndims);
let new_ndims = RustNDIndex::deduce_ndims_after_indexing(indexes, ndims);
create_ndims(unifier, new_ndims)
}
#[must_use]
pub fn index_always_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indexes: &[RustNDIndex<'ctx>],
name: &str,
) -> Self {
let tyctx = generator.type_context(ctx.ctx);
let dst_ndims = self.deduce_ndims_after_indexing_with(&mut ctx.unifier, indexes);
let dst_ndarray = NDArrayObject::alloca(generator, ctx, dst_ndims, self.dtype, name);
let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(tyctx, ctx, indexes);
call_nac3_ndarray_index(
generator,
ctx,
num_indexes,
indexes,
self.instance,
dst_ndarray.instance,
);
dst_ndarray
}
pub fn index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indexes: &[RustNDIndex<'ctx>],
name: &str,
) -> ScalarOrNDArray<'ctx> {
let tyctx = generator.type_context(ctx.ctx);
let sizet_model = IntModel(SizeT);
let subndarray = self.index_always_ndarray(generator, ctx, indexes, name);
if subndarray.is_unsized(&ctx.unifier) {
// TODO: This actually never fails, don't use the `checked_` version.
let value = subndarray.checked_get_nth_element(
generator,
ctx,
sizet_model.const_0(tyctx, ctx.ctx),
name,
);
ScalarOrNDArray::Scalar(ScalarObject { dtype: self.dtype, value })
} else {
ScalarOrNDArray::NDArray(subndarray)
}
}
}

View File

@ -1,5 +1,6 @@
pub mod broadcast;
pub mod factory;
pub mod indexing;
pub mod object;
pub mod util;
pub mod view;

View File

@ -29,6 +29,7 @@ pub enum ScalarOrNDArray<'ctx> {
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.value,

View File

@ -1,4 +1,8 @@
use inkwell::types::BasicType;
use inkwell::{
types::BasicType,
values::{BasicValueEnum, PointerValue},
AddressSpace,
};
use util::gen_model_memcpy;
use crate::{
@ -249,6 +253,37 @@ impl<'ctx> NDArrayObject<'ctx> {
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
}
pub fn checked_get_nth_pelement<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
i: Int<'ctx, SizeT>,
name: &str,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, i);
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name)
.unwrap()
}
pub fn checked_get_nth_element<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
i: Int<'ctx, SizeT>,
name: &str,
) -> BasicValueEnum<'ctx> {
let pelement = self.checked_get_nth_pelement(generator, ctx, i, "pelement");
ctx.builder.build_load(pelement, name).unwrap()
}
#[must_use]
pub fn is_unsized(&self, unifier: &Unifier) -> bool {
extract_ndims(unifier, self.ndims) == 0
}
pub fn size<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,