Implement indexing for ndarray #381
|
@ -1,12 +1,12 @@
|
|||
use inkwell::{
|
||||
IntPredicate,
|
||||
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
|
||||
values::{BasicValueEnum, IntValue, PointerValue},
|
||||
values::{ArrayValue, BasicValueEnum, IntValue, PointerValue},
|
||||
};
|
||||
use crate::codegen::{
|
||||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
|
||||
stmt::gen_for_callback,
|
||||
};
|
||||
|
||||
|
@ -725,7 +725,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
||||
panic!("Expected list[int32] but got {indices_elem_ty}")
|
||||
};
|
||||
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
|
||||
debug_assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
|
||||
|
||||
let index = call_ndarray_flatten_index(
|
||||
generator,
|
||||
|
@ -743,6 +743,92 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||
}
|
||||
}
|
||||
|
||||
pub unsafe fn ptr_offset_unchecked_const(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
indices: ArrayValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let index = call_ndarray_flatten_index_const(
|
||||
generator,
|
||||
ctx,
|
||||
self.0,
|
||||
indices,
|
||||
).unwrap();
|
||||
|
||||
unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.get_ptr(ctx),
|
||||
&[index],
|
||||
name.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the pointer to the data at the index specified by `indices`.
|
||||
pub fn ptr_offset_const(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
indices: ArrayValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let indices_elem_ty = indices.get_type().get_element_type();
|
||||
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
||||
panic!("Expected [int32] but got [{indices_elem_ty}]")
|
||||
};
|
||||
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected [int32] but got [{indices_elem_ty}]");
|
||||
|
||||
let nidx_leq_ndims = ctx.builder.build_int_compare(
|
||||
IntPredicate::SLE,
|
||||
llvm_usize.const_int(indices.get_type().len() as u64, false),
|
||||
self.0.load_ndims(ctx),
|
||||
""
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
nidx_leq_ndims,
|
||||
"0:IndexError",
|
||||
"invalid index to scalar variable",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
for idx in 0..indices.get_type().len() {
|
||||
let i = llvm_usize.const_int(idx as u64, false);
|
||||
|
||||
let dim_idx = ctx.builder
|
||||
.build_extract_value(indices, idx, "")
|
||||
.map(|v| v.into_int_value())
|
||||
.map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, ""))
|
||||
.unwrap();
|
||||
let dim_sz = self.0.get_dims().get(ctx, generator, i, None);
|
||||
|
||||
let dim_lt = ctx.builder.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
dim_idx,
|
||||
dim_sz,
|
||||
""
|
||||
);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
dim_lt,
|
||||
"0:IndexError",
|
||||
"index {0} is out of bounds for axis 0 with size {1}",
|
||||
[Some(dim_idx), Some(dim_sz), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
self.ptr_offset_unchecked_const(ctx, generator, indices, name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the pointer to the data at the index specified by `indices`.
|
||||
pub fn ptr_offset(
|
||||
&self,
|
||||
|
@ -844,6 +930,17 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||
}
|
||||
}
|
||||
|
||||
pub unsafe fn get_unsafe_const(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
indices: ArrayValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_offset_unchecked_const(ctx, generator, indices, name);
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||
}
|
||||
|
||||
pub unsafe fn get_unsafe(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
|
@ -855,6 +952,18 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
|
|||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Returns the data at the index specified by `indices`.
|
||||
pub fn get_const(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
indices: ArrayValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_offset_const(ctx, generator, indices, name);
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Returns the data at the index specified by `indices`.
|
||||
pub fn get(
|
||||
&self,
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
classes::{ListValue, RangeValue},
|
||||
classes::{ListValue, NDArrayValue, RangeValue},
|
||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||
gen_in_range_check,
|
||||
get_llvm_type,
|
||||
|
@ -1190,6 +1190,213 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
|||
}
|
||||
}
|
||||
|
||||
/// Generates code for a subscript expression on an `ndarray`.
|
||||
///
|
||||
/// * `ty` - The `Type` of the `NDArray` elements.
|
||||
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
||||
/// * `v` - The `NDArray` value.
|
||||
/// * `slice` - The slice expression used to subscript into the `ndarray`.
|
||||
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
ndims: Type,
|
||||
v: NDArrayValue<'ctx>,
|
||||
slice: &Expr<Option<Type>>,
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
|
||||
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let ndims = values.iter()
|
||||
.map(|ndim| match *ndim {
|
||||
SymbolValue::U64(v) => Ok(v),
|
||||
SymbolValue::U32(v) => Ok(v as u64),
|
||||
SymbolValue::I32(v) => u64::try_from(v)
|
||||
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
|
||||
SymbolValue::I64(v) => u64::try_from(v)
|
||||
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
assert!(!ndims.is_empty());
|
||||
|
||||
let ndarray_ty_enum = TypeEnum::TNDArray {
|
||||
ty,
|
||||
ndims: ctx.unifier.get_fresh_literal(
|
||||
ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(),
|
||||
None,
|
||||
),
|
||||
};
|
||||
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
||||
|
||||
// Check that len is non-zero
|
||||
let len = v.load_ndims(ctx);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), ""),
|
||||
"0:IndexError",
|
||||
"too many indices for array: array is {0}-dimensional but 1 were indexed",
|
||||
[Some(len), None, None],
|
||||
slice.location,
|
||||
);
|
||||
|
||||
if ndims.len() == 1 && ndims[0] == 1 {
|
||||
// Accessing an element from a 1-dimensional `ndarray`
|
||||
|
||||
if let ExprKind::Slice { .. } = &slice.node {
|
||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||
}
|
||||
|
||||
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
||||
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
|
||||
Ok(Some(v.get_data()
|
||||
.get_const(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.ctx.i32_type().const_array(&[index]),
|
||||
None,
|
||||
)
|
||||
.into()))
|
||||
} else {
|
||||
// Accessing an element from a multi-dimensional `ndarray`
|
||||
|
||||
if let ExprKind::Slice { .. } = &slice.node {
|
||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||
}
|
||||
|
||||
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
||||
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
|
||||
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||
// elements over
|
||||
let subscripted_ndarray = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_ndarray_t.into(),
|
||||
None
|
||||
)?;
|
||||
let ndarray = NDArrayValue::from_ptr_val(
|
||||
subscripted_ndarray,
|
||||
llvm_usize,
|
||||
None
|
||||
);
|
||||
|
||||
let num_dims = v.load_ndims(ctx);
|
||||
ndarray.store_ndims(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), ""),
|
||||
);
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
|
||||
|
||||
let memcpy_fn_name = format!(
|
||||
"llvm.memcpy.p0i8.p0i8.i{}",
|
||||
generator.get_size_type(ctx.ctx).get_bit_width(),
|
||||
);
|
||||
let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| {
|
||||
let fn_type = llvm_void.fn_type(
|
||||
&[
|
||||
llvm_pi8.into(),
|
||||
llvm_pi8.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_i1.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let v_dims_src_ptr = v.get_dims().ptr_offset(
|
||||
ctx,
|
||||
generator,
|
||||
llvm_usize.const_int(1, false),
|
||||
None,
|
||||
);
|
||||
ctx.builder.build_call(
|
||||
memcpy_fn,
|
||||
&[
|
||||
ctx.builder.build_bitcast(
|
||||
ndarray.get_dims().get_ptr(ctx),
|
||||
llvm_pi8,
|
||||
"",
|
||||
).into(),
|
||||
ctx.builder.build_bitcast(
|
||||
v_dims_src_ptr,
|
||||
llvm_pi8,
|
||||
"",
|
||||
).into(),
|
||||
ctx.builder.build_int_mul(
|
||||
ndarray_num_dims.into(),
|
||||
llvm_usize.size_of(),
|
||||
"",
|
||||
).into(),
|
||||
llvm_i1.const_zero().into(),
|
||||
],
|
||||
"",
|
||||
);
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray.load_ndims(ctx),
|
||||
ndarray.get_dims().get_ptr(ctx),
|
||||
);
|
||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||
|
||||
let v_data_src_ptr = v.get_data().ptr_offset_const(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.ctx.i32_type().const_array(&[index]),
|
||||
None
|
||||
);
|
||||
ctx.builder.build_call(
|
||||
memcpy_fn,
|
||||
&[
|
||||
ctx.builder.build_bitcast(
|
||||
ndarray.get_data().get_ptr(ctx),
|
||||
llvm_pi8,
|
||||
"",
|
||||
).into(),
|
||||
ctx.builder.build_bitcast(
|
||||
v_data_src_ptr,
|
||||
llvm_pi8,
|
||||
"",
|
||||
).into(),
|
||||
ctx.builder.build_int_mul(
|
||||
ndarray_num_elems.into(),
|
||||
llvm_ndarray_data_t.size_of().unwrap(),
|
||||
"",
|
||||
).into(),
|
||||
llvm_i1.const_zero().into(),
|
||||
],
|
||||
"",
|
||||
);
|
||||
|
||||
Ok(Some(v.get_ptr().into()))
|
||||
}
|
||||
}
|
||||
|
||||
/// See [`CodeGenerator::gen_expr`].
|
||||
pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
|
@ -1810,8 +2017,22 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||
v.get_data().get(ctx, generator, index, None).into()
|
||||
}
|
||||
}
|
||||
TypeEnum::TNDArray { .. } => {
|
||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||
TypeEnum::TNDArray { ty, ndims } => {
|
||||
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
let v = NDArrayValue::from_ptr_val(v, usize, None);
|
||||
|
||||
return gen_ndarray_subscript_expr(
|
||||
generator,
|
||||
ctx,
|
||||
*ty,
|
||||
*ndims,
|
||||
v,
|
||||
&*slice,
|
||||
)
|
||||
}
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let index: u32 =
|
||||
|
|
|
@ -278,12 +278,13 @@ uint32_t __nac3_ndarray_flatten_index(
|
|||
) {
|
||||
uint32_t idx = 0;
|
||||
uint32_t stride = 1;
|
||||
for (uint32_t i = num_dims - 1; i-- >= 0; ) {
|
||||
if (i < num_indices) {
|
||||
idx += (stride * indices[i]);
|
||||
for (uint32_t i = 0; i < num_dims; ++i) {
|
||||
uint32_t ri = num_dims - i - 1;
|
||||
if (ri < num_indices) {
|
||||
idx += (stride * indices[ri]);
|
||||
}
|
||||
|
||||
stride *= dims[i];
|
||||
stride *= dims[ri];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
@ -296,12 +297,13 @@ uint64_t __nac3_ndarray_flatten_index64(
|
|||
) {
|
||||
uint64_t idx = 0;
|
||||
uint64_t stride = 1;
|
||||
for (uint64_t i = num_dims - 1; i-- >= 0; ) {
|
||||
if (i < num_indices) {
|
||||
idx += (stride * indices[i]);
|
||||
for (uint64_t i = 0; i < num_dims; ++i) {
|
||||
uint64_t ri = num_dims - i - 1;
|
||||
if (ri < num_indices) {
|
||||
idx += (stride * indices[ri]);
|
||||
}
|
||||
|
||||
stride *= dims[i];
|
||||
stride *= dims[ri];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
|
|
@ -10,8 +10,8 @@ use inkwell::{
|
|||
context::Context,
|
||||
memory_buffer::MemoryBuffer,
|
||||
module::Module,
|
||||
types::BasicTypeEnum,
|
||||
values::{FloatValue, IntValue, PointerValue},
|
||||
types::{BasicTypeEnum, IntType},
|
||||
values::{ArrayValue, FloatValue, IntValue, PointerValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
use nac3parser::ast::Expr;
|
||||
|
@ -707,7 +707,75 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
|
|||
Ok(indices)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_flatten_index`.
|
||||
fn call_ndarray_flatten_index_impl<'ctx>(
|
||||
generator: &dyn CodeGenerator,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: PointerValue<'ctx>,
|
||||
indices_size: IntValue<'ctx>,
|
||||
) -> Result<IntValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
debug_assert_eq!(
|
||||
IntType::try_from(indices.get_type().get_element_type())
|
||||
.map(|itype| itype.get_bit_width())
|
||||
.unwrap_or_default(),
|
||||
llvm_i32.get_bit_width(),
|
||||
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
||||
);
|
||||
debug_assert_eq!(
|
||||
indices_size.get_type().get_bit_width(),
|
||||
llvm_usize.get_bit_width(),
|
||||
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
|
||||
);
|
||||
|
||||
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_flatten_index",
|
||||
64 => "__nac3_ndarray_flatten_index64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||
};
|
||||
let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pi32.into(),
|
||||
llvm_usize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.get_dims();
|
||||
|
||||
let index = ctx.builder
|
||||
.build_call(
|
||||
ndarray_flatten_index_fn,
|
||||
&[
|
||||
ndarray_dims.get_ptr(ctx).into(),
|
||||
ndarray_num_dims.into(),
|
||||
indices.into(),
|
||||
indices_size.into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.try_as_basic_value()
|
||||
.map_left(|v| v.into_int_value())
|
||||
.left()
|
||||
.unwrap();
|
||||
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||
/// multidimensional index.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
|
@ -718,51 +786,57 @@ pub fn call_ndarray_flatten_index<'ctx>(
|
|||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: ListValue<'ctx>,
|
||||
) -> Result<IntValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_flatten_index",
|
||||
64 => "__nac3_ndarray_flatten_index64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||
};
|
||||
let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
llvm_pi32.into(),
|
||||
llvm_pusize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.get_dims();
|
||||
let indices_size = indices.load_size(ctx, None);
|
||||
let indices_data = indices.get_data();
|
||||
|
||||
let index = ctx.builder
|
||||
.build_call(
|
||||
ndarray_flatten_index_fn,
|
||||
&[
|
||||
ndarray_num_dims.into(),
|
||||
ndarray_dims.get_ptr(ctx).into(),
|
||||
indices_size.into(),
|
||||
indices_data.get_ptr(ctx).into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.try_as_basic_value()
|
||||
.map_left(|v| v.into_int_value())
|
||||
.left()
|
||||
.unwrap();
|
||||
call_ndarray_flatten_index_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
indices_data.get_ptr(ctx),
|
||||
indices_size,
|
||||
)
|
||||
}
|
||||
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||
/// multidimensional index.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||
pub fn call_ndarray_flatten_index_const<'ctx>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: ArrayValue<'ctx>,
|
||||
) -> Result<IntValue<'ctx>, String> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
Ok(index)
|
||||
}
|
||||
let indices_size = indices.get_type().len();
|
||||
let indices_alloca = generator.gen_array_var_alloc(
|
||||
ctx,
|
||||
indices.get_type().get_element_type(),
|
||||
llvm_usize.const_int(indices_size as u64, false),
|
||||
None
|
||||
)?;
|
||||
for i in 0..indices_size {
|
||||
let v = ctx.builder.build_extract_value(indices, i, "")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let elem_ptr = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
indices_alloca,
|
||||
&[ctx.ctx.i32_type().const_int(i as u64, false)],
|
||||
""
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(elem_ptr, v);
|
||||
}
|
||||
|
||||
call_ndarray_flatten_index_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
indices_alloca,
|
||||
llvm_usize.const_int(indices_size as u64, false),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -1237,6 +1237,67 @@ impl<'a> Inferencer<'a> {
|
|||
Ok(boolean)
|
||||
}
|
||||
|
||||
/// Infers the type of a subscript expression on an `ndarray`.
|
||||
fn infer_subscript_ndarray(
|
||||
&mut self,
|
||||
value: &ast::Expr<Option<Type>>,
|
||||
dummy_tvar: Type,
|
||||
ndims: &Type,
|
||||
) -> InferenceResult {
|
||||
debug_assert!(matches!(
|
||||
&*self.unifier.get_ty_immutable(dummy_tvar),
|
||||
TypeEnum::TVar { is_const_generic: false, .. }
|
||||
));
|
||||
|
||||
let constrained_ty = self.unifier.add_ty(TypeEnum::TNDArray { ty: dummy_tvar, ndims: *ndims });
|
||||
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
|
||||
|
||||
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(*ndims) else {
|
||||
panic!("Expected TLiteral for TNDArray.ndims, got {}", self.unifier.stringify(*ndims))
|
||||
};
|
||||
|
||||
let ndims = values.iter()
|
||||
.map(|ndim| match *ndim {
|
||||
SymbolValue::U64(v) => Ok(v),
|
||||
SymbolValue::U32(v) => Ok(v as u64),
|
||||
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| HashSet::from([
|
||||
format!("Expected non-negative literal for TNDArray.ndims, got {v}"),
|
||||
])),
|
||||
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| HashSet::from([
|
||||
format!("Expected non-negative literal for TNDArray.ndims, got {v}"),
|
||||
])),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
assert!(!ndims.is_empty());
|
||||
|
||||
if ndims.len() == 1 && ndims[0] == 1 {
|
||||
// ndarray[T, Literal[1]] - Index always returns an object of type T
|
||||
|
||||
assert_ne!(ndims[0], 0);
|
||||
|
||||
Ok(dummy_tvar)
|
||||
} else {
|
||||
// ndarray[T, Literal[N]] where N != 1 - Index returns an object of type ndarray[T, Literal[N - 1]]
|
||||
|
||||
if ndims.iter().any(|v| *v == 0) {
|
||||
unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented")
|
||||
}
|
||||
|
||||
let ndims_min_one_ty = self.unifier.get_fresh_literal(
|
||||
ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(),
|
||||
None,
|
||||
);
|
||||
let subscripted_ty = self.unifier.add_ty(TypeEnum::TNDArray {
|
||||
ty: dummy_tvar,
|
||||
ndims: ndims_min_one_ty,
|
||||
});
|
||||
|
||||
Ok(subscripted_ty)
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_subscript(
|
||||
&mut self,
|
||||
value: &ast::Expr<Option<Type>>,
|
||||
|
@ -1258,33 +1319,41 @@ impl<'a> Inferencer<'a> {
|
|||
Ok(list_like_ty)
|
||||
}
|
||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||
// the index is a constant, so value can be a sequence.
|
||||
let ind: Option<i32> = (*val).try_into().ok();
|
||||
let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
|
||||
let map = once((
|
||||
ind.into(),
|
||||
RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)),
|
||||
))
|
||||
.collect();
|
||||
let seq = self.unifier.add_record(map);
|
||||
self.constrain(value.custom.unwrap(), seq, &value.location)?;
|
||||
Ok(ty)
|
||||
if let TypeEnum::TNDArray { ndims, .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
self.infer_subscript_ndarray(value, ty, ndims)
|
||||
} else {
|
||||
// the index is a constant, so value can be a sequence.
|
||||
let ind: Option<i32> = (*val).try_into().ok();
|
||||
let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
|
||||
let map = once((
|
||||
ind.into(),
|
||||
RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)),
|
||||
))
|
||||
.collect();
|
||||
let seq = self.unifier.add_record(map);
|
||||
self.constrain(value.custom.unwrap(), seq, &value.location)?;
|
||||
Ok(ty)
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap())
|
||||
{
|
||||
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
|
||||
}
|
||||
|
||||
// the index is not a constant, so value can only be a list
|
||||
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
|
||||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||
TypeEnum::TNDArray { .. } => todo!(),
|
||||
// the index is not a constant, so value can only be a list-like structure
|
||||
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
TypeEnum::TList { .. } => {
|
||||
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
|
||||
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
||||
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
||||
Ok(ty)
|
||||
}
|
||||
TypeEnum::TNDArray { ndims, .. } => {
|
||||
self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?;
|
||||
self.infer_subscript_ndarray(value, ty, ndims)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
|
||||
Ok(ty)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,11 @@
|
|||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_float64(x: float):
|
||||
...
|
||||
|
||||
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
||||
pass
|
||||
|
||||
|
@ -17,20 +25,27 @@ def test_ndarray_empty():
|
|||
|
||||
def test_ndarray_zeros():
|
||||
n: ndarray[float, 1] = np_zeros([1])
|
||||
output_float64(n[0])
|
||||
consume_ndarray_1(n)
|
||||
|
||||
def test_ndarray_ones():
|
||||
n: ndarray[float, 1] = np_ones([1])
|
||||
output_float64(n[0])
|
||||
consume_ndarray_1(n)
|
||||
|
||||
def test_ndarray_full():
|
||||
n_float: ndarray[float, 1] = np_full([1], 2.0)
|
||||
output_float64(n_float[0])
|
||||
consume_ndarray_1(n_float)
|
||||
n_i32: ndarray[int32, 1] = np_full([1], 2)
|
||||
output_int32(n_i32[0])
|
||||
consume_ndarray_i32_1(n_i32)
|
||||
|
||||
def test_ndarray_eye():
|
||||
n: ndarray[float, 2] = np_eye(2)
|
||||
n0: ndarray[float, 1] = n[0]
|
||||
v: float = n0[0]
|
||||
output_float64(v)
|
||||
consume_ndarray_2(n)
|
||||
|
||||
def test_ndarray_identity():
|
||||
|
|
Loading…
Reference in New Issue