1
0
forked from M-Labs/nac3

core: Fix codegen for tuple-index into ndarray

This commit is contained in:
David Mak 2024-06-20 12:24:26 +08:00 committed by sb10q
parent 635c944c90
commit 0452e6de78

View File

@ -3,8 +3,8 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use crate::{ use crate::{
codegen::{ codegen::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue,
RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
}, },
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type, gen_in_range_check, get_llvm_abi_type, get_llvm_type,
@ -1741,22 +1741,37 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let ndims = values let ndims = values
.iter() .iter()
.map(|ndim| match *ndim { .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
SymbolValue::U64(v) => Ok(v), .collect::<Result<Vec<_>, _>>()
SymbolValue::U32(v) => Ok(u64::from(v)), .map_err(|val| {
SymbolValue::I32(v) => u64::try_from(v) format!(
.map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")), "Expected non-negative literal for ndarray.ndims, got {}",
SymbolValue::I64(v) => u64::try_from(v) i128::try_from(val).unwrap()
.map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")), )
_ => unreachable!(), })?;
})
.collect::<Result<Vec<_>, _>>()?;
assert!(!ndims.is_empty()); assert!(!ndims.is_empty());
let ndarray_ndims_ty = ctx // The number of dimensions subscripted by the index expression.
.unifier // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
.get_fresh_literal(ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(), None); // dimension will remove a dimension.
let subscripted_dims = match &slice.node {
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
if let ExprKind::Slice { .. } = &value_subexpr.node {
acc
} else {
acc + 1
}
}),
ExprKind::Slice { .. } => 0,
_ => 1,
};
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None,
);
let ndarray_ty = let ndarray_ty =
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty)); make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
@ -1859,7 +1874,72 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
} }
}; };
Ok(Some(match &slice.node { let make_indices_arr = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>|
-> Result<_, String> {
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(elts.len() as u64, false),
None,
)?;
for (i, elt) in elts.iter().enumerate() {
let Some(index) = generator.gen_expr(ctx, elt)? else {
return Ok(None);
};
let index = index
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
return Ok(None);
};
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
None,
)
};
ctx.builder.build_store(store_ptr, index).unwrap();
}
Some(index_addr)
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(1u64, false),
None,
)?;
let index =
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder.build_store(store_ptr, index).unwrap();
Some(index_addr)
} else {
None
})
};
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
v.data().get(ctx, generator, &index_addr, None).into()
} else {
match &slice.node {
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
let slices = elts let slices = elts
.iter() .iter()
@ -1885,46 +1965,23 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
} }
_ => { _ => {
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
} else {
return Ok(None);
};
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
ctx.builder.build_store(index_addr, index).unwrap();
if ndims.len() == 1 && ndims[0] == 1 {
// Accessing an element from a 1-dimensional `ndarray`
return Ok(Some(
v.data()
.get(
ctx,
generator,
&ArraySliceValue::from_ptr_val(
index_addr,
llvm_usize.const_int(1, false),
None,
),
None,
)
.into(),
));
}
// Accessing an element from a multi-dimensional `ndarray` // Accessing an element from a multi-dimensional `ndarray`
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
// Create a new array, remove the top dimension from the dimension-size-list, and copy the // Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over // elements over
let subscripted_ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; let subscripted_ndarray =
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None); let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
let num_dims = v.load_ndims(ctx); let num_dims = v.load_ndims(ctx);
ndarray.store_ndims( ndarray.store_ndims(
ctx, ctx,
generator, generator,
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), "").unwrap(), ctx.builder
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
.unwrap(),
); );
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
@ -1958,18 +2015,17 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
); );
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset( let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
ctx,
generator,
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
None,
);
call_memcpy_generic( call_memcpy_generic(
ctx, ctx,
ndarray.data().base_ptr(ctx, generator), ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr, v_data_src_ptr,
ctx.builder ctx.builder
.build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "") .build_int_mul(
ndarray_num_elems,
llvm_ndarray_data_t.size_of().unwrap(),
"",
)
.map(Into::into) .map(Into::into)
.unwrap(), .unwrap(),
llvm_i1.const_zero(), llvm_i1.const_zero(),
@ -1977,6 +2033,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ndarray.as_base_value().into() ndarray.as_base_value().into()
} }
}
})) }))
} }