core/expr: Add support for multi-dim slicing of NDArrays

This commit is contained in:
David Mak 2024-05-30 16:08:15 +08:00
parent c35ad06949
commit ed79d5bb9e
2 changed files with 216 additions and 111 deletions

View File

@ -1667,6 +1667,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
slice: &Expr<Option<Type>>, slice: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
@ -1712,32 +1713,11 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
slice.location, slice.location,
); );
if let ExprKind::Slice { lower, upper, step } = &slice.node { // Normalizes a possibly-negative index to its corresponding positive index
let dim0_sz = unsafe { let normalize_index = |generator: &mut G,
v.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) ctx: &mut CodeGenContext<'ctx, '_>,
}; index: IntValue<'ctx>,
dim: u64| {
let Some((start, stop, step)) = handle_slice_indices(
lower,
upper,
step,
ctx,
generator,
dim0_sz,
)? else { return Ok(None) };
return Ok(Some(numpy::ndarray_sliced_copy(
generator,
ctx,
ty,
v,
&[(start, stop, step)],
)?.as_ptr_value().into()))
}
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
let index = index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
gen_if_else_expr_callback( gen_if_else_expr_callback(
generator, generator,
ctx, ctx,
@ -1757,7 +1737,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
v.dim_sizes().get_typed_unchecked( v.dim_sizes().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_zero(), &llvm_usize.const_int(dim, true),
None, None,
) )
}; };
@ -1770,25 +1750,121 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
}, },
)?.map(BasicValueEnum::into_int_value).unwrap() ).map(|v| v.map(BasicValueEnum::into_int_value))
};
// Converts a slice expression into a slice-range tuple
let expr_to_slice = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
node: &ExprKind<Option<Type>>,
dim: u64| {
match node {
ExprKind::Constant { value: Constant::Int(v), .. } => {
let Some(index) = normalize_index(
generator, ctx, llvm_i32.const_int(*v as u64, true), dim,
)? else {
return Ok(None)
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
ExprKind::Slice { lower, upper, step } => {
let dim_sz = unsafe {
v.dim_sizes()
.get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, false),
None,
)
};
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz)
}
_ => {
let Some(index) = generator.gen_expr(ctx, slice)? else {
return Ok(None)
};
let index = index
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
return Ok(None)
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
}
};
Ok(Some(match &slice.node {
ExprKind::Tuple { elts, .. } => {
let slices = elts.iter().enumerate()
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
.collect::<Result<Vec<_>, _>>()?;
if slices.len() < elts.len() {
return Ok(None)
}
let slices = slices.into_iter()
.map(Option::unwrap)
.collect_vec();
numpy::ndarray_sliced_copy(
generator,
ctx,
ty,
v,
&slices,
)?.as_ptr_value().into()
}
ExprKind::Slice { .. } => {
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
return Ok(None)
};
numpy::ndarray_sliced_copy(
generator,
ctx,
ty,
v,
&[slice],
)?.as_ptr_value().into()
}
_ => {
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
} else { } else {
return Ok(None) return Ok(None)
}; };
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
return Ok(None)
};
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?; let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
ctx.builder.build_store(index_addr, index).unwrap(); ctx.builder.build_store(index_addr, index).unwrap();
if ndims.len() == 1 && ndims[0] == 1 { if ndims.len() == 1 && ndims[0] == 1 {
// Accessing an element from a 1-dimensional `ndarray` // Accessing an element from a 1-dimensional `ndarray`
Ok(Some(v.data() return Ok(Some(v.data()
.get( .get(
ctx, ctx,
generator, generator,
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None), &ArraySliceValue::from_ptr_val(
index_addr,
llvm_usize.const_int(1, false),
None,
),
None, None,
) )
.into())) .into()))
} else { }
// Accessing an element from a multi-dimensional `ndarray` // Accessing an element from a multi-dimensional `ndarray`
// Create a new array, remove the top dimension from the dimension-size-list, and copy the // Create a new array, remove the top dimension from the dimension-size-list, and copy the
@ -1859,8 +1935,9 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
llvm_i1.const_zero(), llvm_i1.const_zero(),
); );
Ok(Some(ndarray.as_ptr_value().into())) ndarray.as_ptr_value().into()
} }
}))
} }
/// See [`CodeGenerator::gen_expr`]. /// See [`CodeGenerator::gen_expr`].

View File

@ -2,6 +2,7 @@ use std::collections::{HashMap, HashSet};
use std::convert::{From, TryInto}; use std::convert::{From, TryInto};
use std::iter::once; use std::iter::once;
use std::{cell::RefCell, sync::Arc}; use std::{cell::RefCell, sync::Arc};
use std::ops::Not;
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
@ -554,7 +555,10 @@ impl<'a> Fold<()> for Inferencer<'a> {
ExprKind::ListComp { .. } ExprKind::ListComp { .. }
| ExprKind::Lambda { .. } | ExprKind::Lambda { .. }
| ExprKind::Call { .. } => expr.custom, // already computed | ExprKind::Call { .. } => expr.custom, // already computed
ExprKind::Slice { .. } => None, // we don't need it for slice ExprKind::Slice { .. } => {
// slices aren't exactly ranges, but for our purposes this should suffice
Some(self.primitives.range)
}
_ => return report_error("not supported", expr.location), _ => return report_error("not supported", expr.location),
}; };
Ok(ast::Expr { custom, location: expr.location, node: expr.node }) Ok(ast::Expr { custom, location: expr.location, node: expr.node })
@ -1642,6 +1646,30 @@ impl<'a> Inferencer<'a> {
} }
} }
} }
ExprKind::Tuple { elts, .. } => {
if value.custom
.unwrap()
.obj_id(self.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)
.not() {
return report_error("Tuple slices are only supported for ndarrays", slice.location)
}
for elt in elts {
if let ExprKind::Slice { lower, upper, step } = &elt.node {
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
}
} else {
self.constrain(elt.custom.unwrap(), self.primitives.int32, &elt.location)?;
}
}
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndarray_ty = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
Ok(ndarray_ty)
}
_ => { _ => {
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location) return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)