forked from M-Labs/nac3
core: Fix codegen for tuple-index into ndarray
This commit is contained in:
parent
635c944c90
commit
0452e6de78
|
@ -3,8 +3,8 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
classes::{
|
classes::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue,
|
ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue,
|
||||||
RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||||
gen_in_range_check, get_llvm_abi_type, get_llvm_type,
|
gen_in_range_check, get_llvm_abi_type, get_llvm_type,
|
||||||
|
@ -1741,22 +1741,37 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
let ndims = values
|
let ndims = values
|
||||||
.iter()
|
.iter()
|
||||||
.map(|ndim| match *ndim {
|
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
|
||||||
SymbolValue::U64(v) => Ok(v),
|
.collect::<Result<Vec<_>, _>>()
|
||||||
SymbolValue::U32(v) => Ok(u64::from(v)),
|
.map_err(|val| {
|
||||||
SymbolValue::I32(v) => u64::try_from(v)
|
format!(
|
||||||
.map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")),
|
"Expected non-negative literal for ndarray.ndims, got {}",
|
||||||
SymbolValue::I64(v) => u64::try_from(v)
|
i128::try_from(val).unwrap()
|
||||||
.map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")),
|
)
|
||||||
_ => unreachable!(),
|
})?;
|
||||||
})
|
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
|
||||||
|
|
||||||
assert!(!ndims.is_empty());
|
assert!(!ndims.is_empty());
|
||||||
|
|
||||||
let ndarray_ndims_ty = ctx
|
// The number of dimensions subscripted by the index expression.
|
||||||
.unifier
|
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
|
||||||
.get_fresh_literal(ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(), None);
|
// dimension will remove a dimension.
|
||||||
|
let subscripted_dims = match &slice.node {
|
||||||
|
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
|
||||||
|
if let ExprKind::Slice { .. } = &value_subexpr.node {
|
||||||
|
acc
|
||||||
|
} else {
|
||||||
|
acc + 1
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
|
ExprKind::Slice { .. } => 0,
|
||||||
|
_ => 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
|
||||||
|
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
let ndarray_ty =
|
let ndarray_ty =
|
||||||
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
|
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
|
||||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||||
|
@ -1859,123 +1874,165 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(match &slice.node {
|
let make_indices_arr = |generator: &mut G,
|
||||||
ExprKind::Tuple { elts, .. } => {
|
ctx: &mut CodeGenContext<'ctx, '_>|
|
||||||
let slices = elts
|
-> Result<_, String> {
|
||||||
.iter()
|
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
|
||||||
.enumerate()
|
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
|
||||||
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
|
let index_addr = generator.gen_array_var_alloc(
|
||||||
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
|
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
|
||||||
if slices.len() < elts.len() {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
|
|
||||||
|
|
||||||
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
|
|
||||||
}
|
|
||||||
|
|
||||||
ExprKind::Slice { .. } => {
|
|
||||||
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
|
|
||||||
return Ok(None);
|
|
||||||
};
|
|
||||||
|
|
||||||
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => {
|
|
||||||
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
|
|
||||||
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
|
||||||
} else {
|
|
||||||
return Ok(None);
|
|
||||||
};
|
|
||||||
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
|
|
||||||
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
|
|
||||||
ctx.builder.build_store(index_addr, index).unwrap();
|
|
||||||
|
|
||||||
if ndims.len() == 1 && ndims[0] == 1 {
|
|
||||||
// Accessing an element from a 1-dimensional `ndarray`
|
|
||||||
|
|
||||||
return Ok(Some(
|
|
||||||
v.data()
|
|
||||||
.get(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&ArraySliceValue::from_ptr_val(
|
|
||||||
index_addr,
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accessing an element from a multi-dimensional `ndarray`
|
|
||||||
|
|
||||||
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
|
||||||
// elements over
|
|
||||||
let subscripted_ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
|
||||||
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
|
|
||||||
|
|
||||||
let num_dims = v.load_ndims(ctx);
|
|
||||||
ndarray.store_ndims(
|
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
llvm_int_ty,
|
||||||
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), "").unwrap(),
|
llvm_usize.const_int(elts.len() as u64, false),
|
||||||
);
|
None,
|
||||||
|
)?;
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
for (i, elt) in elts.iter().enumerate() {
|
||||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
let Some(index) = generator.gen_expr(ctx, elt)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
let index = index
|
||||||
let v_dims_src_ptr = unsafe {
|
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
|
||||||
v.dim_sizes().ptr_offset_unchecked(
|
.into_int_value();
|
||||||
|
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
let store_ptr = unsafe {
|
||||||
|
index_addr.ptr_offset_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(i as u64, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(store_ptr, index).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(index_addr)
|
||||||
|
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
|
||||||
|
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
|
||||||
|
let index_addr = generator.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_int_ty,
|
||||||
|
llvm_usize.const_int(1u64, false),
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let index =
|
||||||
|
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
|
||||||
|
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
|
||||||
|
|
||||||
|
let store_ptr = unsafe {
|
||||||
|
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(store_ptr, index).unwrap();
|
||||||
|
|
||||||
|
Some(index_addr)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
|
||||||
|
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
|
||||||
|
|
||||||
|
v.data().get(ctx, generator, &index_addr, None).into()
|
||||||
|
} else {
|
||||||
|
match &slice.node {
|
||||||
|
ExprKind::Tuple { elts, .. } => {
|
||||||
|
let slices = elts
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
|
||||||
|
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
if slices.len() < elts.len() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
|
||||||
|
|
||||||
|
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
ExprKind::Slice { .. } => {
|
||||||
|
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
// Accessing an element from a multi-dimensional `ndarray`
|
||||||
|
|
||||||
|
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
|
||||||
|
|
||||||
|
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||||
|
// elements over
|
||||||
|
let subscripted_ndarray =
|
||||||
|
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
|
||||||
|
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
|
||||||
|
|
||||||
|
let num_dims = v.load_ndims(ctx);
|
||||||
|
ndarray.store_ndims(
|
||||||
ctx,
|
ctx,
|
||||||
generator,
|
generator,
|
||||||
&llvm_usize.const_int(1, false),
|
ctx.builder
|
||||||
None,
|
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
|
||||||
)
|
.unwrap(),
|
||||||
};
|
);
|
||||||
call_memcpy_generic(
|
|
||||||
ctx,
|
|
||||||
ndarray.dim_sizes().base_ptr(ctx, generator),
|
|
||||||
v_dims_src_ptr,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
|
|
||||||
.map(Into::into)
|
|
||||||
.unwrap(),
|
|
||||||
llvm_i1.const_zero(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
generator,
|
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
|
||||||
ctx,
|
|
||||||
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
|
||||||
(None, None),
|
|
||||||
);
|
|
||||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
|
||||||
|
|
||||||
let v_data_src_ptr = v.data().ptr_offset(
|
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||||
ctx,
|
let v_dims_src_ptr = unsafe {
|
||||||
generator,
|
v.dim_sizes().ptr_offset_unchecked(
|
||||||
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
|
ctx,
|
||||||
None,
|
generator,
|
||||||
);
|
&llvm_usize.const_int(1, false),
|
||||||
call_memcpy_generic(
|
None,
|
||||||
ctx,
|
)
|
||||||
ndarray.data().base_ptr(ctx, generator),
|
};
|
||||||
v_data_src_ptr,
|
call_memcpy_generic(
|
||||||
ctx.builder
|
ctx,
|
||||||
.build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "")
|
ndarray.dim_sizes().base_ptr(ctx, generator),
|
||||||
.map(Into::into)
|
v_dims_src_ptr,
|
||||||
.unwrap(),
|
ctx.builder
|
||||||
llvm_i1.const_zero(),
|
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
|
||||||
);
|
.map(Into::into)
|
||||||
|
.unwrap(),
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
ndarray.as_base_value().into()
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&ndarray.dim_sizes().as_slice_value(ctx, generator),
|
||||||
|
(None, None),
|
||||||
|
);
|
||||||
|
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||||
|
|
||||||
|
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
ndarray.data().base_ptr(ctx, generator),
|
||||||
|
v_data_src_ptr,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_mul(
|
||||||
|
ndarray_num_elems,
|
||||||
|
llvm_ndarray_data_t.size_of().unwrap(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.map(Into::into)
|
||||||
|
.unwrap(),
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
|
ndarray.as_base_value().into()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue