forked from M-Labs/nac3
core/expr: Add support for multi-dim slicing of NDArrays
This commit is contained in:
parent
c35ad06949
commit
ed79d5bb9e
|
@ -1667,6 +1667,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
slice: &Expr<Option<Type>>,
|
slice: &Expr<Option<Type>>,
|
||||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
|
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
|
||||||
|
@ -1712,32 +1713,11 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
slice.location,
|
slice.location,
|
||||||
);
|
);
|
||||||
|
|
||||||
if let ExprKind::Slice { lower, upper, step } = &slice.node {
|
// Normalizes a possibly-negative index to its corresponding positive index
|
||||||
let dim0_sz = unsafe {
|
let normalize_index = |generator: &mut G,
|
||||||
v.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
};
|
index: IntValue<'ctx>,
|
||||||
|
dim: u64| {
|
||||||
let Some((start, stop, step)) = handle_slice_indices(
|
|
||||||
lower,
|
|
||||||
upper,
|
|
||||||
step,
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
dim0_sz,
|
|
||||||
)? else { return Ok(None) };
|
|
||||||
|
|
||||||
return Ok(Some(numpy::ndarray_sliced_copy(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ty,
|
|
||||||
v,
|
|
||||||
&[(start, stop, step)],
|
|
||||||
)?.as_ptr_value().into()))
|
|
||||||
}
|
|
||||||
|
|
||||||
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
|
|
||||||
let index = index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
|
|
||||||
|
|
||||||
gen_if_else_expr_callback(
|
gen_if_else_expr_callback(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -1757,7 +1737,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
v.dim_sizes().get_typed_unchecked(
|
v.dim_sizes().get_typed_unchecked(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_zero(),
|
&llvm_usize.const_int(dim, true),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
@ -1770,97 +1750,194 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
|
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
|
||||||
},
|
},
|
||||||
)?.map(BasicValueEnum::into_int_value).unwrap()
|
).map(|v| v.map(BasicValueEnum::into_int_value))
|
||||||
} else {
|
|
||||||
return Ok(None)
|
|
||||||
};
|
};
|
||||||
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
|
|
||||||
ctx.builder.build_store(index_addr, index).unwrap();
|
|
||||||
|
|
||||||
if ndims.len() == 1 && ndims[0] == 1 {
|
// Converts a slice expression into a slice-range tuple
|
||||||
// Accessing an element from a 1-dimensional `ndarray`
|
let expr_to_slice = |generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
node: &ExprKind<Option<Type>>,
|
||||||
|
dim: u64| {
|
||||||
|
match node {
|
||||||
|
ExprKind::Constant { value: Constant::Int(v), .. } => {
|
||||||
|
let Some(index) = normalize_index(
|
||||||
|
generator, ctx, llvm_i32.const_int(*v as u64, true), dim,
|
||||||
|
)? else {
|
||||||
|
return Ok(None)
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Some(v.data()
|
Ok(Some((index, index, llvm_i32.const_int(1, true))))
|
||||||
.get(
|
}
|
||||||
|
|
||||||
|
ExprKind::Slice { lower, upper, step } => {
|
||||||
|
let dim_sz = unsafe {
|
||||||
|
v.dim_sizes()
|
||||||
|
.get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(dim, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
let Some(index) = generator.gen_expr(ctx, slice)? else {
|
||||||
|
return Ok(None)
|
||||||
|
};
|
||||||
|
let index = index
|
||||||
|
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
|
||||||
|
.into_int_value();
|
||||||
|
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
|
||||||
|
return Ok(None)
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some((index, index, llvm_i32.const_int(1, true))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(match &slice.node {
|
||||||
|
ExprKind::Tuple { elts, .. } => {
|
||||||
|
let slices = elts.iter().enumerate()
|
||||||
|
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
|
||||||
|
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
if slices.len() < elts.len() {
|
||||||
|
return Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
let slices = slices.into_iter()
|
||||||
|
.map(Option::unwrap)
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
numpy::ndarray_sliced_copy(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ty,
|
||||||
|
v,
|
||||||
|
&slices,
|
||||||
|
)?.as_ptr_value().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
ExprKind::Slice { .. } => {
|
||||||
|
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
|
||||||
|
return Ok(None)
|
||||||
|
};
|
||||||
|
|
||||||
|
numpy::ndarray_sliced_copy(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ty,
|
||||||
|
v,
|
||||||
|
&[slice],
|
||||||
|
)?.as_ptr_value().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
|
||||||
|
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
||||||
|
} else {
|
||||||
|
return Ok(None)
|
||||||
|
};
|
||||||
|
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
|
||||||
|
return Ok(None)
|
||||||
|
};
|
||||||
|
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
|
||||||
|
ctx.builder.build_store(index_addr, index).unwrap();
|
||||||
|
|
||||||
|
if ndims.len() == 1 && ndims[0] == 1 {
|
||||||
|
// Accessing an element from a 1-dimensional `ndarray`
|
||||||
|
|
||||||
|
return Ok(Some(v.data()
|
||||||
|
.get(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&ArraySliceValue::from_ptr_val(
|
||||||
|
index_addr,
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accessing an element from a multi-dimensional `ndarray`
|
||||||
|
|
||||||
|
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||||
|
// elements over
|
||||||
|
let subscripted_ndarray = generator.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_ndarray_t.into(),
|
||||||
|
None
|
||||||
|
)?;
|
||||||
|
let ndarray = NDArrayValue::from_ptr_val(
|
||||||
|
subscripted_ndarray,
|
||||||
|
llvm_usize,
|
||||||
|
None
|
||||||
|
);
|
||||||
|
|
||||||
|
let num_dims = v.load_ndims(ctx);
|
||||||
|
ndarray.store_ndims(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), "").unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
||||||
|
|
||||||
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
|
let v_dims_src_ptr = unsafe {
|
||||||
|
v.dim_sizes().ptr_offset_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(1, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
ndarray.dim_sizes().base_ptr(ctx, generator),
|
||||||
|
v_dims_src_ptr,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
|
||||||
|
.map(Into::into)
|
||||||
|
.unwrap(),
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
||||||
|
(None, None),
|
||||||
|
);
|
||||||
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
|
|
||||||
|
let v_data_src_ptr = v.data().ptr_offset(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
|
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
|
||||||
None,
|
None
|
||||||
)
|
);
|
||||||
.into()))
|
call_memcpy_generic(
|
||||||
} else {
|
|
||||||
// Accessing an element from a multi-dimensional `ndarray`
|
|
||||||
|
|
||||||
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
|
||||||
// elements over
|
|
||||||
let subscripted_ndarray = generator.gen_var_alloc(
|
|
||||||
ctx,
|
|
||||||
llvm_ndarray_t.into(),
|
|
||||||
None
|
|
||||||
)?;
|
|
||||||
let ndarray = NDArrayValue::from_ptr_val(
|
|
||||||
subscripted_ndarray,
|
|
||||||
llvm_usize,
|
|
||||||
None
|
|
||||||
);
|
|
||||||
|
|
||||||
let num_dims = v.load_ndims(ctx);
|
|
||||||
ndarray.store_ndims(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), "").unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
|
||||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
|
||||||
let v_dims_src_ptr = unsafe {
|
|
||||||
v.dim_sizes().ptr_offset_unchecked(
|
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
ndarray.data().base_ptr(ctx, generator),
|
||||||
&llvm_usize.const_int(1, false),
|
v_data_src_ptr,
|
||||||
None,
|
ctx.builder
|
||||||
)
|
.build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "")
|
||||||
};
|
.map(Into::into)
|
||||||
call_memcpy_generic(
|
.unwrap(),
|
||||||
ctx,
|
llvm_i1.const_zero(),
|
||||||
ndarray.dim_sizes().base_ptr(ctx, generator),
|
);
|
||||||
v_dims_src_ptr,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
|
|
||||||
.map(Into::into)
|
|
||||||
.unwrap(),
|
|
||||||
llvm_i1.const_zero(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
ndarray.as_ptr_value().into()
|
||||||
generator,
|
}
|
||||||
ctx,
|
}))
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
|
||||||
(None, None),
|
|
||||||
);
|
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
|
||||||
|
|
||||||
let v_data_src_ptr = v.data().ptr_offset(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
|
|
||||||
None
|
|
||||||
);
|
|
||||||
call_memcpy_generic(
|
|
||||||
ctx,
|
|
||||||
ndarray.data().base_ptr(ctx, generator),
|
|
||||||
v_data_src_ptr,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "")
|
|
||||||
.map(Into::into)
|
|
||||||
.unwrap(),
|
|
||||||
llvm_i1.const_zero(),
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(Some(ndarray.as_ptr_value().into()))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_expr`].
|
/// See [`CodeGenerator::gen_expr`].
|
||||||
|
|
|
@ -2,6 +2,7 @@ use std::collections::{HashMap, HashSet};
|
||||||
use std::convert::{From, TryInto};
|
use std::convert::{From, TryInto};
|
||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
use std::{cell::RefCell, sync::Arc};
|
use std::{cell::RefCell, sync::Arc};
|
||||||
|
use std::ops::Not;
|
||||||
|
|
||||||
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
|
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
|
||||||
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
||||||
|
@ -554,7 +555,10 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
ExprKind::ListComp { .. }
|
ExprKind::ListComp { .. }
|
||||||
| ExprKind::Lambda { .. }
|
| ExprKind::Lambda { .. }
|
||||||
| ExprKind::Call { .. } => expr.custom, // already computed
|
| ExprKind::Call { .. } => expr.custom, // already computed
|
||||||
ExprKind::Slice { .. } => None, // we don't need it for slice
|
ExprKind::Slice { .. } => {
|
||||||
|
// slices aren't exactly ranges, but for our purposes this should suffice
|
||||||
|
Some(self.primitives.range)
|
||||||
|
}
|
||||||
_ => return report_error("not supported", expr.location),
|
_ => return report_error("not supported", expr.location),
|
||||||
};
|
};
|
||||||
Ok(ast::Expr { custom, location: expr.location, node: expr.node })
|
Ok(ast::Expr { custom, location: expr.location, node: expr.node })
|
||||||
|
@ -1642,6 +1646,30 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ExprKind::Tuple { elts, .. } => {
|
||||||
|
if value.custom
|
||||||
|
.unwrap()
|
||||||
|
.obj_id(self.unifier)
|
||||||
|
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)
|
||||||
|
.not() {
|
||||||
|
return report_error("Tuple slices are only supported for ndarrays", slice.location)
|
||||||
|
}
|
||||||
|
|
||||||
|
for elt in elts {
|
||||||
|
if let ExprKind::Slice { lower, upper, step } = &elt.node {
|
||||||
|
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
||||||
|
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.constrain(elt.custom.unwrap(), self.primitives.int32, &elt.location)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||||
|
let ndarray_ty = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
|
||||||
|
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
|
||||||
|
Ok(ndarray_ty)
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
|
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
|
||||||
|
|
Loading…
Reference in New Issue