Implement indexing for ndarray #381

Merged
sb10q merged 6 commits from enhance/issue-149-ndarray into master 2024-02-19 17:36:10 +08:00
6 changed files with 572 additions and 82 deletions

View File

@ -1,12 +1,12 @@
use inkwell::{
IntPredicate,
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue},
values::{ArrayValue, BasicValueEnum, IntValue, PointerValue},
};
use crate::codegen::{
CodeGenContext,
CodeGenerator,
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
stmt::gen_for_callback,
};
@ -725,7 +725,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}")
};
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
debug_assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
let index = call_ndarray_flatten_index(
generator,
@ -743,6 +743,92 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
}
}
pub unsafe fn ptr_offset_unchecked_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let index = call_ndarray_flatten_index_const(
generator,
ctx,
self.0,
indices,
).unwrap();
unsafe {
ctx.builder.build_in_bounds_gep(
self.get_ptr(ctx),
&[index],
name.unwrap_or_default(),
)
}
}
/// Returns the pointer to the data at the index specified by `indices`.
pub fn ptr_offset_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_elem_ty = indices.get_type().get_element_type();
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected [int32] but got [{indices_elem_ty}]")
};
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected [int32] but got [{indices_elem_ty}]");
let nidx_leq_ndims = ctx.builder.build_int_compare(
IntPredicate::SLE,
llvm_usize.const_int(indices.get_type().len() as u64, false),
self.0.load_ndims(ctx),
""
);
ctx.make_assert(
generator,
nidx_leq_ndims,
"0:IndexError",
"invalid index to scalar variable",
[None, None, None],
ctx.current_loc,
);
for idx in 0..indices.get_type().len() {
let i = llvm_usize.const_int(idx as u64, false);
let dim_idx = ctx.builder
.build_extract_value(indices, idx, "")
.map(|v| v.into_int_value())
.map(|v| ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, ""))
.unwrap();
let dim_sz = self.0.get_dims().get(ctx, generator, i, None);
let dim_lt = ctx.builder.build_int_compare(
IntPredicate::SLT,
dim_idx,
dim_sz,
""
);
ctx.make_assert(
generator,
dim_lt,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(dim_idx), Some(dim_sz), None],
ctx.current_loc,
);
}
unsafe {
self.ptr_offset_unchecked_const(ctx, generator, indices, name)
}
}
/// Returns the pointer to the data at the index specified by `indices`.
pub fn ptr_offset(
&self,
@ -844,6 +930,17 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
}
}
pub unsafe fn get_unsafe_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset_unchecked_const(ctx, generator, indices, name);
ctx.builder.build_load(ptr, name.unwrap_or_default())
}
pub unsafe fn get_unsafe(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -855,6 +952,18 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
ctx.builder.build_load(ptr, name.unwrap_or_default())
}
/// Returns the data at the index specified by `indices`.
pub fn get_const(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
indices: ArrayValue<'ctx>,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset_const(ctx, generator, indices, name);
ctx.builder.build_load(ptr, name.unwrap_or_default())
}
/// Returns the data at the index specified by `indices`.
pub fn get(
&self,

View File

@ -2,7 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use crate::{
codegen::{
classes::{ListValue, RangeValue},
classes::{ListValue, NDArrayValue, RangeValue},
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check,
get_llvm_type,
@ -1190,6 +1190,213 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
}
}
/// Generates code for a subscript expression on an `ndarray`.
///
/// * `ty` - The `Type` of the `NDArray` elements.
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `v` - The `NDArray` value.
/// * `slice` - The slice expression used to subscript into the `ndarray`.
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
ndims: Type,
v: NDArrayValue<'ctx>,
slice: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let llvm_void = ctx.ctx.void_type();
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
unreachable!()
};
let ndims = values.iter()
.map(|ndim| match *ndim {
SymbolValue::U64(v) => Ok(v),
SymbolValue::U32(v) => Ok(v as u64),
SymbolValue::I32(v) => u64::try_from(v)
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
SymbolValue::I64(v) => u64::try_from(v)
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
_ => unreachable!(),
})
.collect::<Result<Vec<_>, _>>()?;
assert!(!ndims.is_empty());
let ndarray_ty_enum = TypeEnum::TNDArray {
ty,
ndims: ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(),
None,
),
};
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
// Check that len is non-zero
let len = v.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), ""),
"0:IndexError",
"too many indices for array: array is {0}-dimensional but 1 were indexed",
[Some(len), None, None],
slice.location,
);
if ndims.len() == 1 && ndims[0] == 1 {
// Accessing an element from a 1-dimensional `ndarray`
if let ExprKind::Slice { .. } = &slice.node {
return Err(String::from("subscript operator for ndarray not implemented"))
}
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
} else {
return Ok(None)
};
Ok(Some(v.get_data()
.get_const(
ctx,
generator,
ctx.ctx.i32_type().const_array(&[index]),
None,
)
.into()))
} else {
// Accessing an element from a multi-dimensional `ndarray`
if let ExprKind::Slice { .. } = &slice.node {
return Err(String::from("subscript operator for ndarray not implemented"))
}
let index = if let Some(v) = generator.gen_expr(ctx, slice)? {
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
} else {
return Ok(None)
};
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None
)?;
let ndarray = NDArrayValue::from_ptr_val(
subscripted_ndarray,
llvm_usize,
None
);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), ""),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
let memcpy_fn_name = format!(
"llvm.memcpy.p0i8.p0i8.i{}",
generator.get_size_type(ctx.ctx).get_bit_width(),
);
let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[
llvm_pi8.into(),
llvm_pi8.into(),
llvm_usize.into(),
llvm_i1.into(),
],
false,
);
ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = v.get_dims().ptr_offset(
ctx,
generator,
llvm_usize.const_int(1, false),
None,
);
ctx.builder.build_call(
memcpy_fn,
&[
ctx.builder.build_bitcast(
ndarray.get_dims().get_ptr(ctx),
llvm_pi8,
"",
).into(),
ctx.builder.build_bitcast(
v_dims_src_ptr,
llvm_pi8,
"",
).into(),
ctx.builder.build_int_mul(
ndarray_num_dims.into(),
llvm_usize.size_of(),
"",
).into(),
llvm_i1.const_zero().into(),
],
"",
);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
ndarray.load_ndims(ctx),
ndarray.get_dims().get_ptr(ctx),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.get_data().ptr_offset_const(
ctx,
generator,
ctx.ctx.i32_type().const_array(&[index]),
None
);
ctx.builder.build_call(
memcpy_fn,
&[
ctx.builder.build_bitcast(
ndarray.get_data().get_ptr(ctx),
llvm_pi8,
"",
).into(),
ctx.builder.build_bitcast(
v_data_src_ptr,
llvm_pi8,
"",
).into(),
ctx.builder.build_int_mul(
ndarray_num_elems.into(),
llvm_ndarray_data_t.size_of().unwrap(),
"",
).into(),
llvm_i1.const_zero().into(),
],
"",
);
Ok(Some(v.get_ptr().into()))
}
}
/// See [`CodeGenerator::gen_expr`].
pub fn gen_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
@ -1810,8 +2017,22 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
v.get_data().get(ctx, generator, index, None).into()
}
}
TypeEnum::TNDArray { .. } => {
return Err(String::from("subscript operator for ndarray not implemented"))
TypeEnum::TNDArray { ty, ndims } => {
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value()
} else {
return Ok(None)
};
let v = NDArrayValue::from_ptr_val(v, usize, None);
return gen_ndarray_subscript_expr(
generator,
ctx,
*ty,
*ndims,
v,
&*slice,
)
}
TypeEnum::TTuple { .. } => {
let index: u32 =

View File

@ -278,12 +278,13 @@ uint32_t __nac3_ndarray_flatten_index(
) {
uint32_t idx = 0;
uint32_t stride = 1;
for (uint32_t i = num_dims - 1; i-- >= 0; ) {
if (i < num_indices) {
idx += (stride * indices[i]);
for (uint32_t i = 0; i < num_dims; ++i) {
uint32_t ri = num_dims - i - 1;
if (ri < num_indices) {
idx += (stride * indices[ri]);
}
stride *= dims[i];
stride *= dims[ri];
}
return idx;
}
@ -296,12 +297,13 @@ uint64_t __nac3_ndarray_flatten_index64(
) {
uint64_t idx = 0;
uint64_t stride = 1;
for (uint64_t i = num_dims - 1; i-- >= 0; ) {
if (i < num_indices) {
idx += (stride * indices[i]);
for (uint64_t i = 0; i < num_dims; ++i) {
uint64_t ri = num_dims - i - 1;
if (ri < num_indices) {
idx += (stride * indices[ri]);
}
stride *= dims[i];
stride *= dims[ri];
}
return idx;
}

View File

@ -10,8 +10,8 @@ use inkwell::{
context::Context,
memory_buffer::MemoryBuffer,
module::Module,
types::BasicTypeEnum,
values::{FloatValue, IntValue, PointerValue},
types::{BasicTypeEnum, IntType},
values::{ArrayValue, FloatValue, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use nac3parser::ast::Expr;
@ -707,7 +707,75 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
Ok(indices)
}
/// Generates a call to `__nac3_ndarray_flatten_index`.
fn call_ndarray_flatten_index_impl<'ctx>(
generator: &dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: PointerValue<'ctx>,
indices_size: IntValue<'ctx>,
) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.get_type().get_element_type())
.map(|itype| itype.get_bit_width())
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices_size.get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pi32.into(),
llvm_usize.into(),
],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.get_dims();
let index = ctx.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.get_ptr(ctx).into(),
ndarray_num_dims.into(),
indices.into(),
indices_size.into(),
],
"",
)
.try_as_basic_value()
.map_left(|v| v.into_int_value())
.left()
.unwrap();
Ok(index)
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
/// `NDArray`.
@ -718,51 +786,57 @@ pub fn call_ndarray_flatten_index<'ctx>(
ndarray: NDArrayValue<'ctx>,
indices: ListValue<'ctx>,
) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_usize.into(),
llvm_pusize.into(),
llvm_pi32.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.get_dims();
let indices_size = indices.load_size(ctx, None);
let indices_data = indices.get_data();
let index = ctx.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_num_dims.into(),
ndarray_dims.get_ptr(ctx).into(),
indices_size.into(),
indices_data.get_ptr(ctx).into(),
],
"",
)
.try_as_basic_value()
.map_left(|v| v.into_int_value())
.left()
.unwrap();
call_ndarray_flatten_index_impl(
generator,
ctx,
ndarray,
indices_data.get_ptr(ctx),
indices_size,
)
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index_const<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: ArrayValue<'ctx>,
) -> Result<IntValue<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
Ok(index)
}
let indices_size = indices.get_type().len();
let indices_alloca = generator.gen_array_var_alloc(
ctx,
indices.get_type().get_element_type(),
llvm_usize.const_int(indices_size as u64, false),
None
)?;
for i in 0..indices_size {
let v = ctx.builder.build_extract_value(indices, i, "")
.unwrap()
.into_int_value();
let elem_ptr = unsafe {
ctx.builder.build_in_bounds_gep(
indices_alloca,
&[ctx.ctx.i32_type().const_int(i as u64, false)],
""
)
};
ctx.builder.build_store(elem_ptr, v);
}
call_ndarray_flatten_index_impl(
generator,
ctx,
ndarray,
indices_alloca,
llvm_usize.const_int(indices_size as u64, false),
)
}

View File

@ -1237,6 +1237,67 @@ impl<'a> Inferencer<'a> {
Ok(boolean)
}
/// Infers the type of a subscript expression on an `ndarray`.
fn infer_subscript_ndarray(
&mut self,
value: &ast::Expr<Option<Type>>,
dummy_tvar: Type,
ndims: &Type,
) -> InferenceResult {
debug_assert!(matches!(
&*self.unifier.get_ty_immutable(dummy_tvar),
TypeEnum::TVar { is_const_generic: false, .. }
));
let constrained_ty = self.unifier.add_ty(TypeEnum::TNDArray { ty: dummy_tvar, ndims: *ndims });
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(*ndims) else {
panic!("Expected TLiteral for TNDArray.ndims, got {}", self.unifier.stringify(*ndims))
};
let ndims = values.iter()
.map(|ndim| match *ndim {
SymbolValue::U64(v) => Ok(v),
SymbolValue::U32(v) => Ok(v as u64),
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| HashSet::from([
format!("Expected non-negative literal for TNDArray.ndims, got {v}"),
])),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| HashSet::from([
format!("Expected non-negative literal for TNDArray.ndims, got {v}"),
])),
_ => unreachable!(),
})
.collect::<Result<Vec<_>, _>>()?;
assert!(!ndims.is_empty());
if ndims.len() == 1 && ndims[0] == 1 {
// ndarray[T, Literal[1]] - Index always returns an object of type T
assert_ne!(ndims[0], 0);
Ok(dummy_tvar)
} else {
// ndarray[T, Literal[N]] where N != 1 - Index returns an object of type ndarray[T, Literal[N - 1]]
if ndims.iter().any(|v| *v == 0) {
unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented")
}
let ndims_min_one_ty = self.unifier.get_fresh_literal(
ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(),
None,
);
let subscripted_ty = self.unifier.add_ty(TypeEnum::TNDArray {
ty: dummy_tvar,
ndims: ndims_min_one_ty,
});
Ok(subscripted_ty)
}
}
fn infer_subscript(
&mut self,
value: &ast::Expr<Option<Type>>,
@ -1258,33 +1319,41 @@ impl<'a> Inferencer<'a> {
Ok(list_like_ty)
}
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
// the index is a constant, so value can be a sequence.
let ind: Option<i32> = (*val).try_into().ok();
let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
let map = once((
ind.into(),
RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)),
))
.collect();
let seq = self.unifier.add_record(map);
self.constrain(value.custom.unwrap(), seq, &value.location)?;
Ok(ty)
if let TypeEnum::TNDArray { ndims, .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
self.infer_subscript_ndarray(value, ty, ndims)
} else {
// the index is a constant, so value can be a sequence.
let ind: Option<i32> = (*val).try_into().ok();
let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
let map = once((
ind.into(),
RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)),
))
.collect();
let seq = self.unifier.add_record(map);
self.constrain(value.custom.unwrap(), seq, &value.location)?;
Ok(ty)
}
}
_ => {
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap())
{
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
}
// the index is not a constant, so value can only be a list
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
TypeEnum::TNDArray { .. } => todo!(),
// the index is not a constant, so value can only be a list-like structure
match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => {
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
let list = self.unifier.add_ty(TypeEnum::TList { ty });
self.constrain(value.custom.unwrap(), list, &value.location)?;
Ok(ty)
}
TypeEnum::TNDArray { ndims, .. } => {
self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?;
self.infer_subscript_ndarray(value, ty, ndims)
}
_ => unreachable!(),
};
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
Ok(ty)
}
}
}
}

View File

@ -1,3 +1,11 @@
@extern
def output_int32(x: int32):
...
@extern
def output_float64(x: float):
...
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass
@ -17,20 +25,27 @@ def test_ndarray_empty():
def test_ndarray_zeros():
n: ndarray[float, 1] = np_zeros([1])
output_float64(n[0])
consume_ndarray_1(n)
def test_ndarray_ones():
n: ndarray[float, 1] = np_ones([1])
output_float64(n[0])
consume_ndarray_1(n)
def test_ndarray_full():
n_float: ndarray[float, 1] = np_full([1], 2.0)
output_float64(n_float[0])
consume_ndarray_1(n_float)
n_i32: ndarray[int32, 1] = np_full([1], 2)
output_int32(n_i32[0])
consume_ndarray_i32_1(n_i32)
def test_ndarray_eye():
n: ndarray[float, 2] = np_eye(2)
n0: ndarray[float, 1] = n[0]
v: float = n0[0]
output_float64(v)
consume_ndarray_2(n)
def test_ndarray_identity():