forked from M-Labs/nac3
core/ndstrides: refactor numpy indexing
This commit is contained in:
parent
86b0d31290
commit
2ab7b299b8
|
@ -2,6 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
irrt::slice::{RustUserSlice, SliceIndex},
|
irrt::slice::{RustUserSlice, SliceIndex},
|
||||||
|
numpy_new::object::{NDArrayObject, ScalarOrNDArray},
|
||||||
structure::ndarray::NpArray,
|
structure::ndarray::NpArray,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -18,7 +19,6 @@ use crate::{
|
||||||
call_memcpy_generic,
|
call_memcpy_generic,
|
||||||
},
|
},
|
||||||
need_sret, numpy,
|
need_sret, numpy,
|
||||||
numpy_new::util::alloca_ndarray,
|
|
||||||
stmt::{
|
stmt::{
|
||||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||||
gen_var,
|
gen_var,
|
||||||
|
@ -35,7 +35,7 @@ use crate::{
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
types::{AnyType, BasicType, BasicTypeEnum},
|
types::{AnyType, BasicType, BasicTypeEnum},
|
||||||
values::{BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
|
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
use itertools::{chain, izip, Either, Itertools};
|
use itertools::{chain, izip, Either, Itertools};
|
||||||
|
@ -44,7 +44,7 @@ use nac3parser::ast::{
|
||||||
StrRef, Unaryop,
|
StrRef, Unaryop,
|
||||||
};
|
};
|
||||||
|
|
||||||
use ndarray::indexing::{call_nac3_ndarray_index, RustNDIndex};
|
use ndarray::indexing::RustNDIndex;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
model::*,
|
model::*,
|
||||||
|
@ -2130,23 +2130,13 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates code for a subscript expression on an `ndarray`.
|
pub fn gen_ndarray_subscript_ndindexes<'ctx, G: CodeGenerator>(
|
||||||
///
|
|
||||||
/// * `elem_ty` - The `Type` of the `NDArray` elements.
|
|
||||||
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
|
||||||
/// * `src_ndarray` - The `NDArray` value.
|
|
||||||
/// * `subscript` - The subscript expression used to index into the `ndarray`.
|
|
||||||
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
|
||||||
ndims: Type,
|
|
||||||
src_ndarray: Ptr<'ctx, StructModel<NpArray>>,
|
|
||||||
subscript: &Expr<Option<Type>>,
|
subscript: &Expr<Option<Type>>,
|
||||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
) -> Result<Vec<RustNDIndex<'ctx>>, String> {
|
||||||
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
|
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
|
||||||
let tyctx = generator.type_context(ctx.ctx);
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
let sizet_model = IntModel(SizeT);
|
|
||||||
let slice_index_model = IntModel(SliceIndex::default());
|
let slice_index_model = IntModel(SliceIndex::default());
|
||||||
|
|
||||||
// Annoying notes about `slice`
|
// Annoying notes about `slice`
|
||||||
|
@ -2215,66 +2205,23 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
};
|
};
|
||||||
rust_ndindexes.push(ndindex);
|
rust_ndindexes.push(ndindex);
|
||||||
}
|
}
|
||||||
|
Ok(rust_ndindexes)
|
||||||
// Extract the `ndims` from a `Type` to `i128`
|
|
||||||
// We *HAVE* to know this statically, this is used to determine
|
|
||||||
// whether this subscript expression returns a scalar or an ndarray
|
|
||||||
let TypeEnum::TLiteral { values: ndims_values, .. } = &*ctx.unifier.get_ty_immutable(ndims)
|
|
||||||
else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
assert_eq!(ndims_values.len(), 1);
|
|
||||||
let src_ndims = i128::try_from(ndims_values[0].clone()).unwrap();
|
|
||||||
|
|
||||||
// Check for "too many indices for array: array is ..." error
|
|
||||||
if src_ndims < rust_ndindexes.len() as i128 {
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.ctx.bool_type().const_int(1, false),
|
|
||||||
"0:IndexError",
|
|
||||||
"too many indices for array: array is {0}-dimensional, but {1} were indexed",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let dst_ndims = RustNDIndex::deduce_ndims_after_slicing(&rust_ndindexes, src_ndims as i32);
|
/// Generates code for a subscript expression on an `ndarray`.
|
||||||
let dst_ndarray = alloca_ndarray(
|
///
|
||||||
generator,
|
/// * `elem_ty` - The `Type` of the `NDArray` elements.
|
||||||
ctx,
|
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
||||||
sizet_model.constant(tyctx, ctx.ctx, dst_ndims as u64),
|
/// * `src_ndarray` - The `NDArray` value.
|
||||||
"subndarray",
|
/// * `subscript` - The subscript expression used to index into the `ndarray`.
|
||||||
);
|
pub fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
// Prepare the subscripts
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
let (num_ndindexes, ndindexes) = RustNDIndex::alloca_ndindexes(tyctx, ctx, &rust_ndindexes);
|
ndarray: NDArrayObject<'ctx>,
|
||||||
|
subscript: &Expr<Option<Type>>,
|
||||||
// NOTE: IRRT does check for indexing errors
|
) -> Result<ScalarOrNDArray<'ctx>, String> {
|
||||||
call_nac3_ndarray_index(generator, ctx, num_ndindexes, ndindexes, src_ndarray, dst_ndarray);
|
let indexes = gen_ndarray_subscript_ndindexes(generator, ctx, subscript)?;
|
||||||
|
Ok(ndarray.index(generator, ctx, &indexes, "subndarray"))
|
||||||
// ...and return the result, with two cases
|
|
||||||
let result_llvm_value: BasicValueEnum<'_> = if dst_ndims == 0 {
|
|
||||||
// 1) ndims == 0 (this happens when you do `np.zerps((3, 4))[1, 1]`), return the element
|
|
||||||
|
|
||||||
let pelement = dst_ndarray.gep(ctx, |f| f.data).load(tyctx, ctx, "pelement"); // `*data` points to the first element by definition
|
|
||||||
|
|
||||||
// Cast the opaque `pelement` ptr to `elem_ty`
|
|
||||||
let elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
let pelement = ctx
|
|
||||||
.builder
|
|
||||||
.build_pointer_cast(
|
|
||||||
pelement.value,
|
|
||||||
elem_ty.ptr_type(AddressSpace::default()),
|
|
||||||
"pelement_casted",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
ctx.builder.build_load(pelement, "element").unwrap().as_basic_value_enum()
|
|
||||||
} else {
|
|
||||||
// 2) ndims > 0 (other cases), return subndarray
|
|
||||||
dst_ndarray.value.as_basic_value_enum()
|
|
||||||
};
|
|
||||||
Ok(Some(ValueEnum::Dynamic(result_llvm_value)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_expr`].
|
/// See [`CodeGenerator::gen_expr`].
|
||||||
|
@ -2920,7 +2867,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
let tyctx = generator.type_context(ctx.ctx);
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
let pndarray_model = PtrModel(StructModel(NpArray));
|
let pndarray_model = PtrModel(StructModel(NpArray));
|
||||||
|
|
||||||
let (dtype, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
let (&dtype, &ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
||||||
|
|
||||||
let Some(ndarray) = generator.gen_expr(ctx, value)? else {
|
let Some(ndarray) = generator.gen_expr(ctx, value)? else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
|
@ -2929,10 +2876,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
let ndarray =
|
let ndarray =
|
||||||
ndarray.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
|
ndarray.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
|
||||||
let ndarray = pndarray_model.check_value(tyctx, ctx.ctx, ndarray).unwrap();
|
let ndarray = pndarray_model.check_value(tyctx, ctx.ctx, ndarray).unwrap();
|
||||||
|
let ndarray = NDArrayObject { dtype, ndims, instance: ndarray };
|
||||||
|
|
||||||
return gen_ndarray_subscript_expr(
|
let result = gen_ndarray_subscript_expr(generator, ctx, ndarray, slice)?;
|
||||||
generator, ctx, *dtype, *ndims, ndarray, slice,
|
return Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())));
|
||||||
);
|
|
||||||
}
|
}
|
||||||
TypeEnum::TTuple { .. } => {
|
TypeEnum::TTuple { .. } => {
|
||||||
let index: u32 =
|
let index: u32 =
|
||||||
|
|
|
@ -76,7 +76,7 @@ impl<'ctx> RustNDIndex<'ctx> {
|
||||||
dst_ndindex_ptr.gep(ctx, |f| f.data).store(ctx, data);
|
dst_ndindex_ptr.gep(ctx, |f| f.data).store(ctx, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allocate an array of `NDIndex`es onto the stack and return its stack pointer.
|
/// Allocate an array of `NDIndex`es on the stack and return its stack pointer.
|
||||||
pub fn alloca_ndindexes(
|
pub fn alloca_ndindexes(
|
||||||
tyctx: TypeContext<'ctx>,
|
tyctx: TypeContext<'ctx>,
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
@ -97,10 +97,10 @@ impl<'ctx> RustNDIndex<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn deduce_ndims_after_slicing(slices: &[RustNDIndex], original_ndims: i32) -> i32 {
|
pub fn deduce_ndims_after_indexing(indices: &[RustNDIndex], original_ndims: u64) -> u64 {
|
||||||
let mut final_ndims: i32 = original_ndims;
|
let mut final_ndims = original_ndims;
|
||||||
for slice in slices {
|
for index in indices {
|
||||||
match slice {
|
match index {
|
||||||
RustNDIndex::SingleElement(_) => {
|
RustNDIndex::SingleElement(_) => {
|
||||||
final_ndims -= 1;
|
final_ndims -= 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
irrt::ndarray::indexing::{call_nac3_ndarray_index, RustNDIndex},
|
||||||
|
model::*,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::{Type, Unifier},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
object::{NDArrayObject, ScalarObject, ScalarOrNDArray},
|
||||||
|
util::{create_ndims, extract_ndims},
|
||||||
|
};
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
pub fn deduce_ndims_after_indexing_with(
|
||||||
|
&self,
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
indexes: &[RustNDIndex<'ctx>],
|
||||||
|
) -> Type {
|
||||||
|
let ndims = extract_ndims(unifier, self.ndims);
|
||||||
|
let new_ndims = RustNDIndex::deduce_ndims_after_indexing(indexes, ndims);
|
||||||
|
create_ndims(unifier, new_ndims)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn index_always_ndarray<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
indexes: &[RustNDIndex<'ctx>],
|
||||||
|
name: &str,
|
||||||
|
) -> Self {
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
|
||||||
|
let dst_ndims = self.deduce_ndims_after_indexing_with(&mut ctx.unifier, indexes);
|
||||||
|
let dst_ndarray = NDArrayObject::alloca(generator, ctx, dst_ndims, self.dtype, name);
|
||||||
|
|
||||||
|
let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(tyctx, ctx, indexes);
|
||||||
|
call_nac3_ndarray_index(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
num_indexes,
|
||||||
|
indexes,
|
||||||
|
self.instance,
|
||||||
|
dst_ndarray.instance,
|
||||||
|
);
|
||||||
|
|
||||||
|
dst_ndarray
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn index<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
indexes: &[RustNDIndex<'ctx>],
|
||||||
|
name: &str,
|
||||||
|
) -> ScalarOrNDArray<'ctx> {
|
||||||
|
let tyctx = generator.type_context(ctx.ctx);
|
||||||
|
let sizet_model = IntModel(SizeT);
|
||||||
|
|
||||||
|
let subndarray = self.index_always_ndarray(generator, ctx, indexes, name);
|
||||||
|
if subndarray.is_unsized(&ctx.unifier) {
|
||||||
|
// TODO: This actually never fails, don't use the `checked_` version.
|
||||||
|
let value = subndarray.checked_get_nth_element(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
sizet_model.const_0(tyctx, ctx.ctx),
|
||||||
|
name,
|
||||||
|
);
|
||||||
|
ScalarOrNDArray::Scalar(ScalarObject { dtype: self.dtype, value })
|
||||||
|
} else {
|
||||||
|
ScalarOrNDArray::NDArray(subndarray)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod broadcast;
|
pub mod broadcast;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
|
pub mod indexing;
|
||||||
pub mod object;
|
pub mod object;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
pub mod view;
|
pub mod view;
|
||||||
|
|
|
@ -29,6 +29,7 @@ pub enum ScalarOrNDArray<'ctx> {
|
||||||
|
|
||||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
||||||
|
#[must_use]
|
||||||
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
||||||
match self {
|
match self {
|
||||||
ScalarOrNDArray::Scalar(scalar) => scalar.value,
|
ScalarOrNDArray::Scalar(scalar) => scalar.value,
|
||||||
|
|
|
@ -1,4 +1,8 @@
|
||||||
use inkwell::types::BasicType;
|
use inkwell::{
|
||||||
|
types::BasicType,
|
||||||
|
values::{BasicValueEnum, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
use util::gen_model_memcpy;
|
use util::gen_model_memcpy;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -249,6 +253,37 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
|
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn checked_get_nth_pelement<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
i: Int<'ctx, SizeT>,
|
||||||
|
name: &str,
|
||||||
|
) -> PointerValue<'ctx> {
|
||||||
|
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
|
||||||
|
|
||||||
|
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, i);
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn checked_get_nth_element<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
i: Int<'ctx, SizeT>,
|
||||||
|
name: &str,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
let pelement = self.checked_get_nth_pelement(generator, ctx, i, "pelement");
|
||||||
|
ctx.builder.build_load(pelement, name).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_unsized(&self, unifier: &Unifier) -> bool {
|
||||||
|
extract_ndims(unifier, self.ndims) == 0
|
||||||
|
}
|
||||||
|
|
||||||
pub fn size<G: CodeGenerator + ?Sized>(
|
pub fn size<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
|
Loading…
Reference in New Issue