Compare commits
4 Commits
669c6aca6b
...
3a8c385e01
Author | SHA1 | Date |
---|---|---|
lyken | 3a8c385e01 | |
lyken | 221de4d06a | |
lyken | fb9fe8edf2 | |
lyken | 894083c6a3 |
|
@ -995,8 +995,10 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
ctx.builder.position_at_end(init_bb);
|
ctx.builder.position_at_end(init_bb);
|
||||||
|
|
||||||
let Comprehension { target, iter, ifs, .. } = &generators[0];
|
let Comprehension { target, iter, ifs, .. } = &generators[0];
|
||||||
|
|
||||||
|
let iter_ty = iter.custom.unwrap();
|
||||||
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
|
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
|
||||||
v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())?
|
v.to_basic_value_enum(ctx, generator, iter_ty)?
|
||||||
} else {
|
} else {
|
||||||
for bb in [test_bb, body_bb, cont_bb] {
|
for bb in [test_bb, body_bb, cont_bb] {
|
||||||
ctx.builder.position_at_end(bb);
|
ctx.builder.position_at_end(bb);
|
||||||
|
@ -1014,10 +1016,12 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
ctx.builder.build_store(index, zero_size_t).unwrap();
|
ctx.builder.build_store(index, zero_size_t).unwrap();
|
||||||
|
|
||||||
let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap());
|
let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap());
|
||||||
let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
|
|
||||||
let list;
|
let list;
|
||||||
|
|
||||||
if is_range {
|
match &*ctx.unifier.get_ty(iter_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
||||||
let (start, stop, step) = destructure_range(ctx, iter_val);
|
let (start, stop, step) = destructure_range(ctx, iter_val);
|
||||||
let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap();
|
let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap();
|
||||||
|
@ -1025,7 +1029,8 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
// the length may be 1 more than the actual length if the division is exact, but the
|
// the length may be 1 more than the actual length if the division is exact, but the
|
||||||
// length is a upper bound only anyway so it does not matter.
|
// length is a upper bound only anyway so it does not matter.
|
||||||
let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap();
|
let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap();
|
||||||
let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap();
|
let length =
|
||||||
|
ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap();
|
||||||
// in case length is non-positive
|
// in case length is non-positive
|
||||||
let is_valid =
|
let is_valid =
|
||||||
ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap();
|
ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap();
|
||||||
|
@ -1034,7 +1039,9 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
.builder
|
.builder
|
||||||
.build_select(
|
.build_select(
|
||||||
is_valid,
|
is_valid,
|
||||||
ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len").unwrap(),
|
ctx.builder
|
||||||
|
.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len")
|
||||||
|
.unwrap(),
|
||||||
zero_size_t,
|
zero_size_t,
|
||||||
"listcomp.alloc_size",
|
"listcomp.alloc_size",
|
||||||
)
|
)
|
||||||
|
@ -1053,7 +1060,11 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_conditional_branch(gen_in_range_check(ctx, start, stop, step), test_bb, cont_bb)
|
.build_conditional_branch(
|
||||||
|
gen_in_range_check(ctx, start, stop, step),
|
||||||
|
test_bb,
|
||||||
|
cont_bb,
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(test_bb);
|
ctx.builder.position_at_end(test_bb);
|
||||||
|
@ -1068,11 +1079,18 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
.unwrap();
|
.unwrap();
|
||||||
ctx.builder.build_store(i, tmp).unwrap();
|
ctx.builder.build_store(i, tmp).unwrap();
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_conditional_branch(gen_in_range_check(ctx, tmp, stop, step), body_bb, cont_bb)
|
.build_conditional_branch(
|
||||||
|
gen_in_range_check(ctx, tmp, stop, step),
|
||||||
|
body_bb,
|
||||||
|
cont_bb,
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(body_bb);
|
ctx.builder.position_at_end(body_bb);
|
||||||
} else {
|
}
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
let length = ctx
|
let length = ctx
|
||||||
.build_gep_and_load(
|
.build_gep_and_load(
|
||||||
iter_val.into_pointer_value(),
|
iter_val.into_pointer_value(),
|
||||||
|
@ -1088,7 +1106,8 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
ctx.builder.build_unconditional_branch(test_bb).unwrap();
|
ctx.builder.build_unconditional_branch(test_bb).unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(test_bb);
|
ctx.builder.position_at_end(test_bb);
|
||||||
let tmp = ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap();
|
let tmp =
|
||||||
|
ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap();
|
||||||
let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap();
|
let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap();
|
||||||
ctx.builder.build_store(counter, tmp).unwrap();
|
ctx.builder.build_store(counter, tmp).unwrap();
|
||||||
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap();
|
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap();
|
||||||
|
@ -1103,7 +1122,14 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
)
|
)
|
||||||
.into_pointer_value();
|
.into_pointer_value();
|
||||||
let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val"));
|
let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val"));
|
||||||
generator.gen_assign(ctx, target, val.into())?;
|
generator.gen_assign(ctx, target, val.into(), elt.custom.unwrap())?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!(
|
||||||
|
"unsupported list comprehension iterator type: {}",
|
||||||
|
ctx.unifier.stringify(iter_ty)
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Emits the content of `cont_bb`
|
// Emits the content of `cont_bb`
|
||||||
|
|
|
@ -123,11 +123,45 @@ pub trait CodeGenerator {
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
target: &Expr<Option<Type>>,
|
target: &Expr<Option<Type>>,
|
||||||
value: ValueEnum<'ctx>,
|
value: ValueEnum<'ctx>,
|
||||||
|
value_ty: Type,
|
||||||
) -> Result<(), String>
|
) -> Result<(), String>
|
||||||
where
|
where
|
||||||
Self: Sized,
|
Self: Sized,
|
||||||
{
|
{
|
||||||
gen_assign(self, ctx, target, value)
|
gen_assign(self, ctx, target, value, value_ty)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate code for an assignment expression where LHS is a `"target_list"`.
|
||||||
|
///
|
||||||
|
/// See <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
|
||||||
|
fn gen_assign_target_list<'ctx>(
|
||||||
|
&mut self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
targets: &Vec<Expr<Option<Type>>>,
|
||||||
|
value: ValueEnum<'ctx>,
|
||||||
|
value_ty: Type,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
gen_assign_target_list(self, ctx, targets, value, value_ty)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate code for an item assignment.
|
||||||
|
///
|
||||||
|
/// i.e., `target[key] = value`
|
||||||
|
fn gen_setitem<'ctx>(
|
||||||
|
&mut self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
target: &Expr<Option<Type>>,
|
||||||
|
key: &Expr<Option<Type>>,
|
||||||
|
value: ValueEnum<'ctx>,
|
||||||
|
value_ty: Type,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
gen_setitem(self, ctx, target, key, value, value_ty)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate code for a while expression.
|
/// Generate code for a while expression.
|
||||||
|
|
|
@ -10,10 +10,10 @@ use crate::{
|
||||||
expr::gen_binop_expr,
|
expr::gen_binop_expr,
|
||||||
gen_in_range_check,
|
gen_in_range_check,
|
||||||
},
|
},
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
toplevel::{DefinitionId, TopLevelDef},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
magic_methods::Binop,
|
magic_methods::Binop,
|
||||||
typedef::{FunSignature, Type, TypeEnum},
|
typedef::{iter_type_vars, FunSignature, Type, TypeEnum},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
|
@ -23,10 +23,10 @@ use inkwell::{
|
||||||
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
||||||
IntPredicate,
|
IntPredicate,
|
||||||
};
|
};
|
||||||
|
use itertools::{izip, Itertools};
|
||||||
use nac3parser::ast::{
|
use nac3parser::ast::{
|
||||||
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
|
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
|
||||||
};
|
};
|
||||||
use std::convert::TryFrom;
|
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_var_alloc`].
|
/// See [`CodeGenerator::gen_var_alloc`].
|
||||||
pub fn gen_var<'ctx>(
|
pub fn gen_var<'ctx>(
|
||||||
|
@ -97,8 +97,6 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||||
pattern: &Expr<Option<Type>>,
|
pattern: &Expr<Option<Type>>,
|
||||||
name: Option<&str>,
|
name: Option<&str>,
|
||||||
) -> Result<Option<PointerValue<'ctx>>, String> {
|
) -> Result<Option<PointerValue<'ctx>>, String> {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
// very similar to gen_expr, but we don't do an extra load at the end
|
// very similar to gen_expr, but we don't do an extra load at the end
|
||||||
// and we flatten nested tuples
|
// and we flatten nested tuples
|
||||||
Ok(Some(match &pattern.node {
|
Ok(Some(match &pattern.node {
|
||||||
|
@ -137,65 +135,6 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
ExprKind::Subscript { value, slice, .. } => {
|
|
||||||
match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() {
|
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
|
|
||||||
let v = generator
|
|
||||||
.gen_expr(ctx, value)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
|
||||||
.into_pointer_value();
|
|
||||||
let v = ListValue::from_ptr_val(v, llvm_usize, None);
|
|
||||||
let len = v.load_size(ctx, Some("len"));
|
|
||||||
let raw_index = generator
|
|
||||||
.gen_expr(ctx, slice)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
|
|
||||||
.into_int_value();
|
|
||||||
let raw_index = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext")
|
|
||||||
.unwrap();
|
|
||||||
// handle negative index
|
|
||||||
let is_negative = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::SLT,
|
|
||||||
raw_index,
|
|
||||||
generator.get_size_type(ctx.ctx).const_zero(),
|
|
||||||
"is_neg",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted").unwrap();
|
|
||||||
let index = ctx
|
|
||||||
.builder
|
|
||||||
.build_select(is_negative, adjusted, raw_index, "index")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
// unsigned less than is enough, because negative index after adjustment is
|
|
||||||
// bigger than the length (for unsigned cmp)
|
|
||||||
let bound_check = ctx
|
|
||||||
.builder
|
|
||||||
.build_int_compare(IntPredicate::ULT, index, len, "inbound")
|
|
||||||
.unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
bound_check,
|
|
||||||
"0:IndexError",
|
|
||||||
"index {0} out of bounds 0:{1}",
|
|
||||||
[Some(raw_index), Some(len), None],
|
|
||||||
slice.location,
|
|
||||||
);
|
|
||||||
v.data().ptr_offset(ctx, generator, &index, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
@ -206,70 +145,20 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
target: &Expr<Option<Type>>,
|
target: &Expr<Option<Type>>,
|
||||||
value: ValueEnum<'ctx>,
|
value: ValueEnum<'ctx>,
|
||||||
|
value_ty: Type,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
// See https://docs.python.org/3/reference/simple_stmts.html#assignment-statements.
|
||||||
|
|
||||||
match &target.node {
|
match &target.node {
|
||||||
ExprKind::Tuple { elts, .. } => {
|
ExprKind::Subscript { value: target, slice: key, .. } => {
|
||||||
let BasicValueEnum::StructValue(v) =
|
// Handle "slicing" or "subscription"
|
||||||
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
generator.gen_setitem(ctx, target, key, value, value_ty)?;
|
||||||
else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
|
|
||||||
for (i, elt) in elts.iter().enumerate() {
|
|
||||||
let v = ctx
|
|
||||||
.builder
|
|
||||||
.build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem")
|
|
||||||
.unwrap();
|
|
||||||
generator.gen_assign(ctx, elt, v.into())?;
|
|
||||||
}
|
}
|
||||||
}
|
ExprKind::Tuple { elts, .. } | ExprKind::List { elts, .. } => {
|
||||||
ExprKind::Subscript { value: ls, slice, .. }
|
// Fold on `"[" [target_list] "]"` and `"(" [target_list] ")"`
|
||||||
if matches!(&slice.node, ExprKind::Slice { .. }) =>
|
generator.gen_assign_target_list(ctx, elts, value, value_ty)?;
|
||||||
{
|
|
||||||
let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() };
|
|
||||||
|
|
||||||
let ls = generator
|
|
||||||
.gen_expr(ctx, ls)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(ctx, generator, ls.custom.unwrap())?
|
|
||||||
.into_pointer_value();
|
|
||||||
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
|
|
||||||
let Some((start, end, step)) =
|
|
||||||
handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))?
|
|
||||||
else {
|
|
||||||
return Ok(());
|
|
||||||
};
|
|
||||||
let value = value
|
|
||||||
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
|
||||||
.into_pointer_value();
|
|
||||||
let value = ListValue::from_ptr_val(value, llvm_usize, None);
|
|
||||||
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
|
||||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
|
||||||
*params.iter().next().unwrap().1
|
|
||||||
}
|
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let ty = ctx.get_llvm_type(generator, ty);
|
|
||||||
let Some(src_ind) = handle_slice_indices(
|
|
||||||
&None,
|
|
||||||
&None,
|
|
||||||
&None,
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
value.load_size(ctx, None),
|
|
||||||
)?
|
|
||||||
else {
|
|
||||||
return Ok(());
|
|
||||||
};
|
|
||||||
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
|
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
// Handle attribute and direct variable assignments.
|
||||||
let name = if let ExprKind::Name { id, .. } = &target.node {
|
let name = if let ExprKind::Name { id, .. } = &target.node {
|
||||||
format!("{id}.addr")
|
format!("{id}.addr")
|
||||||
} else {
|
} else {
|
||||||
|
@ -293,6 +182,233 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// See [`CodeGenerator::gen_assign_target_list`].
|
||||||
|
pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
targets: &Vec<Expr<Option<Type>>>,
|
||||||
|
value: ValueEnum<'ctx>,
|
||||||
|
value_ty: Type,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// Deconstruct the tuple `value`
|
||||||
|
let BasicValueEnum::StructValue(tuple) = value.to_basic_value_enum(ctx, generator, value_ty)?
|
||||||
|
else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
// NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer.
|
||||||
|
let TypeEnum::TTuple { ty: tuple_tys } = &*ctx.unifier.get_ty(value_ty) else {
|
||||||
|
unreachable!();
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(tuple.get_type().count_fields() as usize, tuple_tys.len());
|
||||||
|
|
||||||
|
let tuple = (0..tuple.get_type().count_fields())
|
||||||
|
.map(|i| ctx.builder.build_extract_value(tuple, i, "item").unwrap())
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
// Find the starred target if it exists.
|
||||||
|
let mut starred_target_index: Option<usize> = None; // Index of the "starred" target. If it exists, there may only be one.
|
||||||
|
for (i, target) in targets.iter().enumerate() {
|
||||||
|
if matches!(target.node, ExprKind::Starred { .. }) {
|
||||||
|
assert!(starred_target_index.is_none()); // The typechecker ensures this
|
||||||
|
starred_target_index = Some(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(starred_target_index) = starred_target_index {
|
||||||
|
assert!(tuple_tys.len() >= targets.len() - 1); // The typechecker ensures this
|
||||||
|
|
||||||
|
let a = starred_target_index; // Number of RHS values before the starred target
|
||||||
|
let b = tuple_tys.len() - (targets.len() - 1 - starred_target_index); // Number of RHS values after the starred target
|
||||||
|
// Thus `tuple[a..b]` is assigned to the starred target.
|
||||||
|
|
||||||
|
// Handle assignment before the starred target
|
||||||
|
for (target, val, val_ty) in
|
||||||
|
izip!(&targets[..starred_target_index], &tuple[..a], &tuple_tys[..a])
|
||||||
|
{
|
||||||
|
generator.gen_assign(ctx, target, ValueEnum::Dynamic(*val), *val_ty)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle assignment to the starred target
|
||||||
|
if let ExprKind::Starred { value: target, .. } = &targets[starred_target_index].node {
|
||||||
|
let vals = &tuple[a..b];
|
||||||
|
let val_tys = &tuple_tys[a..b];
|
||||||
|
|
||||||
|
// Create a sub-tuple from `value` for the starred target.
|
||||||
|
let sub_tuple_ty = ctx
|
||||||
|
.ctx
|
||||||
|
.struct_type(&vals.iter().map(BasicValueEnum::get_type).collect_vec(), false);
|
||||||
|
let psub_tuple_val =
|
||||||
|
ctx.builder.build_alloca(sub_tuple_ty, "starred_target_value_ptr").unwrap();
|
||||||
|
for (i, val) in vals.iter().enumerate() {
|
||||||
|
let pitem = ctx
|
||||||
|
.builder
|
||||||
|
.build_struct_gep(psub_tuple_val, i as u32, "starred_target_value_item")
|
||||||
|
.unwrap();
|
||||||
|
ctx.builder.build_store(pitem, *val).unwrap();
|
||||||
|
}
|
||||||
|
let sub_tuple_val =
|
||||||
|
ctx.builder.build_load(psub_tuple_val, "starred_target_value").unwrap();
|
||||||
|
|
||||||
|
// Create the typechecker type of the sub-tuple
|
||||||
|
let sub_tuple_ty = ctx.unifier.add_ty(TypeEnum::TTuple { ty: val_tys.to_vec() });
|
||||||
|
|
||||||
|
// Now assign with that sub-tuple to the starred target.
|
||||||
|
generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?;
|
||||||
|
} else {
|
||||||
|
unreachable!() // The typechecker ensures this
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle assignment after the starred target
|
||||||
|
for (target, val, val_ty) in
|
||||||
|
izip!(&targets[starred_target_index + 1..], &tuple[b..], &tuple_tys[b..])
|
||||||
|
{
|
||||||
|
generator.gen_assign(ctx, target, ValueEnum::Dynamic(*val), *val_ty)?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert_eq!(tuple_tys.len(), targets.len()); // The typechecker ensures this
|
||||||
|
|
||||||
|
for (target, val, val_ty) in izip!(targets, tuple, tuple_tys) {
|
||||||
|
generator.gen_assign(ctx, target, ValueEnum::Dynamic(val), *val_ty)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// See [`CodeGenerator::gen_setitem`].
|
||||||
|
pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
target: &Expr<Option<Type>>,
|
||||||
|
key: &Expr<Option<Type>>,
|
||||||
|
value: ValueEnum<'ctx>,
|
||||||
|
value_ty: Type,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let target_ty = target.custom.unwrap();
|
||||||
|
let key_ty = key.custom.unwrap();
|
||||||
|
|
||||||
|
match &*ctx.unifier.get_ty(target_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, params: list_params, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// Handle list item assignment
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let target_item_ty = iter_type_vars(list_params).next().unwrap().ty;
|
||||||
|
|
||||||
|
let target = generator
|
||||||
|
.gen_expr(ctx, target)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, target_ty)?
|
||||||
|
.into_pointer_value();
|
||||||
|
let target = ListValue::from_ptr_val(target, llvm_usize, None);
|
||||||
|
|
||||||
|
if let ExprKind::Slice { .. } = &key.node {
|
||||||
|
// Handle assigning to a slice
|
||||||
|
let ExprKind::Slice { lower, upper, step } = &key.node else { unreachable!() };
|
||||||
|
let Some((start, end, step)) = handle_slice_indices(
|
||||||
|
lower,
|
||||||
|
upper,
|
||||||
|
step,
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
target.load_size(ctx, None),
|
||||||
|
)?
|
||||||
|
else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let value =
|
||||||
|
value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value();
|
||||||
|
let value = ListValue::from_ptr_val(value, llvm_usize, None);
|
||||||
|
|
||||||
|
let target_item_ty = ctx.get_llvm_type(generator, target_item_ty);
|
||||||
|
let Some(src_ind) = handle_slice_indices(
|
||||||
|
&None,
|
||||||
|
&None,
|
||||||
|
&None,
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
value.load_size(ctx, None),
|
||||||
|
)?
|
||||||
|
else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
list_slice_assignment(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
target_item_ty,
|
||||||
|
target,
|
||||||
|
(start, end, step),
|
||||||
|
value,
|
||||||
|
src_ind,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// Handle assigning to an index
|
||||||
|
let len = target.load_size(ctx, Some("len"));
|
||||||
|
|
||||||
|
let index = generator
|
||||||
|
.gen_expr(ctx, key)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, key_ty)?
|
||||||
|
.into_int_value();
|
||||||
|
let index = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// handle negative index
|
||||||
|
let is_negative = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
index,
|
||||||
|
generator.get_size_type(ctx.ctx).const_zero(),
|
||||||
|
"is_neg",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let adjusted = ctx.builder.build_int_add(index, len, "adjusted").unwrap();
|
||||||
|
let index = ctx
|
||||||
|
.builder
|
||||||
|
.build_select(is_negative, adjusted, index, "index")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// unsigned less than is enough, because negative index after adjustment is
|
||||||
|
// bigger than the length (for unsigned cmp)
|
||||||
|
let bound_check = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::ULT, index, len, "inbound")
|
||||||
|
.unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
bound_check,
|
||||||
|
"0:IndexError",
|
||||||
|
"index {0} out of bounds 0:{1}",
|
||||||
|
[Some(index), Some(len), None],
|
||||||
|
key.location,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Write value to index on list
|
||||||
|
let item_ptr =
|
||||||
|
target.data().ptr_offset(ctx, generator, &index, Some("list_item_ptr"));
|
||||||
|
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
|
||||||
|
ctx.builder.build_store(item_ptr, value).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// Handle NDArray item assignment
|
||||||
|
todo!("ndarray subscript assignment is not yet implemented");
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_for`].
|
/// See [`CodeGenerator::gen_for`].
|
||||||
pub fn gen_for<G: CodeGenerator>(
|
pub fn gen_for<G: CodeGenerator>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -315,9 +431,6 @@ pub fn gen_for<G: CodeGenerator>(
|
||||||
let orelse_bb =
|
let orelse_bb =
|
||||||
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
|
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
|
||||||
|
|
||||||
// Whether the iterable is a range() expression
|
|
||||||
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
|
|
||||||
|
|
||||||
// The BB containing the increment expression
|
// The BB containing the increment expression
|
||||||
let incr_bb = ctx.ctx.append_basic_block(current, "for.incr");
|
let incr_bb = ctx.ctx.append_basic_block(current, "for.incr");
|
||||||
// The BB containing the loop condition check
|
// The BB containing the loop condition check
|
||||||
|
@ -326,17 +439,23 @@ pub fn gen_for<G: CodeGenerator>(
|
||||||
// store loop bb information and restore it later
|
// store loop bb information and restore it later
|
||||||
let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb));
|
let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb));
|
||||||
|
|
||||||
|
let iter_ty = iter.custom.unwrap();
|
||||||
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
|
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
|
||||||
v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())?
|
v.to_basic_value_enum(ctx, generator, iter_ty)?
|
||||||
} else {
|
} else {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
};
|
};
|
||||||
if is_iterable_range_expr {
|
|
||||||
|
match &*ctx.unifier.get_ty(iter_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
||||||
// Internal variable for loop; Cannot be assigned
|
// Internal variable for loop; Cannot be assigned
|
||||||
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
|
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
|
||||||
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
|
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
|
||||||
let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))?
|
let Some(target_i) =
|
||||||
|
generator.gen_store_target(ctx, target, Some("for.target.addr"))?
|
||||||
else {
|
else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -345,8 +464,10 @@ pub fn gen_for<G: CodeGenerator>(
|
||||||
ctx.builder.build_store(i, start).unwrap();
|
ctx.builder.build_store(i, start).unwrap();
|
||||||
|
|
||||||
// Check "If step is zero, ValueError is raised."
|
// Check "If step is zero, ValueError is raised."
|
||||||
let rangenez =
|
let rangenez = ctx
|
||||||
ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "").unwrap();
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "")
|
||||||
|
.unwrap();
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
rangenez,
|
rangenez,
|
||||||
|
@ -363,7 +484,10 @@ pub fn gen_for<G: CodeGenerator>(
|
||||||
.build_conditional_branch(
|
.build_conditional_branch(
|
||||||
gen_in_range_check(
|
gen_in_range_check(
|
||||||
ctx,
|
ctx,
|
||||||
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
|
ctx.builder
|
||||||
|
.build_load(i, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap(),
|
||||||
stop,
|
stop,
|
||||||
step,
|
step,
|
||||||
),
|
),
|
||||||
|
@ -393,7 +517,10 @@ pub fn gen_for<G: CodeGenerator>(
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
generator.gen_block(ctx, body.iter())?;
|
generator.gen_block(ctx, body.iter())?;
|
||||||
} else {
|
}
|
||||||
|
TypeEnum::TObj { obj_id, params: list_params, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
|
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
|
||||||
ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap();
|
ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap();
|
||||||
let len = ctx
|
let len = ctx
|
||||||
|
@ -431,9 +558,14 @@ pub fn gen_for<G: CodeGenerator>(
|
||||||
.map(BasicValueEnum::into_int_value)
|
.map(BasicValueEnum::into_int_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
|
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
|
||||||
generator.gen_assign(ctx, target, val.into())?;
|
let val_ty = iter_type_vars(list_params).next().unwrap().ty;
|
||||||
|
generator.gen_assign(ctx, target, val.into(), val_ty)?;
|
||||||
generator.gen_block(ctx, body.iter())?;
|
generator.gen_block(ctx, body.iter())?;
|
||||||
}
|
}
|
||||||
|
_ => {
|
||||||
|
panic!("unsupported for loop iterator type: {}", ctx.unifier.stringify(iter_ty));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (k, (_, _, counter)) in &var_assignment {
|
for (k, (_, _, counter)) in &var_assignment {
|
||||||
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
|
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
|
||||||
|
@ -1588,14 +1720,14 @@ pub fn gen_stmt<G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
StmtKind::AnnAssign { target, value, .. } => {
|
StmtKind::AnnAssign { target, value, .. } => {
|
||||||
if let Some(value) = value {
|
if let Some(value) = value {
|
||||||
let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) };
|
let Some(value_enum) = generator.gen_expr(ctx, value)? else { return Ok(()) };
|
||||||
generator.gen_assign(ctx, target, value)?;
|
generator.gen_assign(ctx, target, value_enum, value.custom.unwrap())?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
StmtKind::Assign { targets, value, .. } => {
|
StmtKind::Assign { targets, value, .. } => {
|
||||||
let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) };
|
let Some(value_enum) = generator.gen_expr(ctx, value)? else { return Ok(()) };
|
||||||
for target in targets {
|
for target in targets {
|
||||||
generator.gen_assign(ctx, target, value.clone())?;
|
generator.gen_assign(ctx, target, value_enum.clone(), value.custom.unwrap())?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
StmtKind::Continue { .. } => {
|
StmtKind::Continue { .. } => {
|
||||||
|
@ -1609,15 +1741,16 @@ pub fn gen_stmt<G: CodeGenerator>(
|
||||||
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
|
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
|
||||||
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
|
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
|
||||||
StmtKind::AugAssign { target, op, value, .. } => {
|
StmtKind::AugAssign { target, op, value, .. } => {
|
||||||
let value = gen_binop_expr(
|
let value_enum = gen_binop_expr(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
target,
|
target,
|
||||||
Binop::aug_assign(*op),
|
Binop::aug_assign(*op),
|
||||||
value,
|
value,
|
||||||
stmt.location,
|
stmt.location,
|
||||||
)?;
|
)?
|
||||||
generator.gen_assign(ctx, target, value.unwrap())?;
|
.unwrap();
|
||||||
|
generator.gen_assign(ctx, target, value_enum, value.custom.unwrap())?;
|
||||||
}
|
}
|
||||||
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
|
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
|
||||||
StmtKind::Raise { exc, .. } => {
|
StmtKind::Raise { exc, .. } => {
|
||||||
|
|
|
@ -34,13 +34,18 @@ impl<'a> Inferencer<'a> {
|
||||||
self.should_have_value(pattern)?;
|
self.should_have_value(pattern)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
ExprKind::Tuple { elts, .. } => {
|
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
|
||||||
for elt in elts {
|
for elt in elts {
|
||||||
self.check_pattern(elt, defined_identifiers)?;
|
self.check_pattern(elt, defined_identifiers)?;
|
||||||
self.should_have_value(elt)?;
|
self.should_have_value(elt)?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
ExprKind::Starred { value, .. } => {
|
||||||
|
self.check_pattern(value, defined_identifiers)?;
|
||||||
|
self.should_have_value(value)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
ExprKind::Subscript { value, slice, .. } => {
|
ExprKind::Subscript { value, slice, .. } => {
|
||||||
self.check_expr(value, defined_identifiers)?;
|
self.check_expr(value, defined_identifiers)?;
|
||||||
self.should_have_value(value)?;
|
self.should_have_value(value)?;
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::convert::{From, TryInto};
|
use std::convert::{From, TryInto};
|
||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
use std::ops::Not;
|
|
||||||
use std::{cell::RefCell, sync::Arc};
|
use std::{cell::RefCell, sync::Arc};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
|
@ -19,6 +18,7 @@ use crate::{
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
TopLevelContext, TopLevelDef,
|
TopLevelContext, TopLevelDef,
|
||||||
},
|
},
|
||||||
|
typecheck::typedef::Mapping,
|
||||||
};
|
};
|
||||||
use itertools::{izip, Itertools};
|
use itertools::{izip, Itertools};
|
||||||
use nac3parser::ast::{
|
use nac3parser::ast::{
|
||||||
|
@ -100,16 +100,18 @@ pub struct Inferencer<'a> {
|
||||||
pub in_handler: bool,
|
pub in_handler: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InferenceError = HashSet<String>;
|
||||||
|
|
||||||
struct NaiveFolder();
|
struct NaiveFolder();
|
||||||
impl Fold<()> for NaiveFolder {
|
impl Fold<()> for NaiveFolder {
|
||||||
type TargetU = Option<Type>;
|
type TargetU = Option<Type>;
|
||||||
type Error = HashSet<String>;
|
type Error = InferenceError;
|
||||||
fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
|
fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn report_error<T>(msg: &str, location: Location) -> Result<T, HashSet<String>> {
|
fn report_error<T>(msg: &str, location: Location) -> Result<T, InferenceError> {
|
||||||
Err(HashSet::from([format!("{msg} at {location}")]))
|
Err(HashSet::from([format!("{msg} at {location}")]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,30 +119,48 @@ fn report_type_error<T>(
|
||||||
kind: TypeErrorKind,
|
kind: TypeErrorKind,
|
||||||
loc: Option<Location>,
|
loc: Option<Location>,
|
||||||
unifier: &Unifier,
|
unifier: &Unifier,
|
||||||
) -> Result<T, HashSet<String>> {
|
) -> Result<T, InferenceError> {
|
||||||
Err(HashSet::from([TypeError::new(kind, loc).to_display(unifier).to_string()]))
|
Err(HashSet::from([TypeError::new(kind, loc).to_display(unifier).to_string()]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Traverse through a LHS expression in an assignment and set [`ExprContext`] to [`ExprContext::Store`]
|
||||||
|
/// when appropriate.
|
||||||
|
///
|
||||||
|
/// nac3parser's `ExprContext` output is generally incorrect, and requires manual fixes.
|
||||||
|
fn fix_assignment_target_context(node: &mut ast::Located<ExprKind>) {
|
||||||
|
match &mut node.node {
|
||||||
|
ExprKind::Name { ctx, .. }
|
||||||
|
| ExprKind::Attribute { ctx, .. }
|
||||||
|
| ExprKind::Subscript { ctx, .. } => {
|
||||||
|
*ctx = ExprContext::Store;
|
||||||
|
}
|
||||||
|
ExprKind::Starred { ctx, value } => {
|
||||||
|
*ctx = ExprContext::Store;
|
||||||
|
fix_assignment_target_context(value);
|
||||||
|
}
|
||||||
|
ExprKind::Tuple { ctx, elts } | ExprKind::List { ctx, elts } => {
|
||||||
|
*ctx = ExprContext::Store;
|
||||||
|
elts.iter_mut().for_each(fix_assignment_target_context);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a> Fold<()> for Inferencer<'a> {
|
impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
type TargetU = Option<Type>;
|
type TargetU = Option<Type>;
|
||||||
type Error = HashSet<String>;
|
type Error = InferenceError;
|
||||||
|
|
||||||
fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
|
fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fold_stmt(
|
fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
|
||||||
&mut self,
|
|
||||||
mut node: ast::Stmt<()>,
|
|
||||||
) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
|
|
||||||
let stmt = match node.node {
|
let stmt = match node.node {
|
||||||
// we don't want fold over type annotation
|
// we don't want fold over type annotation
|
||||||
ast::StmtKind::AnnAssign { mut target, annotation, value, simple, config_comment } => {
|
ast::StmtKind::AnnAssign { mut target, annotation, value, simple, config_comment } => {
|
||||||
|
fix_assignment_target_context(&mut target); // Fix parser bug
|
||||||
|
|
||||||
self.infer_pattern(&target)?;
|
self.infer_pattern(&target)?;
|
||||||
// fix parser problem...
|
|
||||||
if let ExprKind::Attribute { ctx, .. } = &mut target.node {
|
|
||||||
*ctx = ExprContext::Store;
|
|
||||||
}
|
|
||||||
|
|
||||||
let target = Box::new(self.fold_expr(*target)?);
|
let target = Box::new(self.fold_expr(*target)?);
|
||||||
let value = if let Some(v) = value {
|
let value = if let Some(v) = value {
|
||||||
|
@ -302,69 +322,41 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
custom: None,
|
custom: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => {
|
ast::StmtKind::Assign { mut targets, type_comment, config_comment, value, .. } => {
|
||||||
for target in &mut *targets {
|
// Fix parser bug
|
||||||
if let ExprKind::Attribute { ctx, .. } = &mut target.node {
|
targets.iter_mut().for_each(fix_assignment_target_context);
|
||||||
*ctx = ExprContext::Store;
|
|
||||||
}
|
// NOTE: Do not register identifiers into `self.defined_identifiers` before checking targets
|
||||||
}
|
// and value, otherwise the Inferencer might use undefined variables in `self.defined_identifiers`
|
||||||
if targets.iter().all(|t| matches!(t.node, ExprKind::Name { .. })) {
|
// and produce strange errors.
|
||||||
let ast::StmtKind::Assign { targets, value, .. } = node.node else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
|
|
||||||
let value = self.fold_expr(*value)?;
|
let value = self.fold_expr(*value)?;
|
||||||
let value_ty = value.custom.unwrap();
|
|
||||||
let targets: Result<Vec<_>, _> = targets
|
|
||||||
.into_iter()
|
|
||||||
.map(|target| {
|
|
||||||
let ExprKind::Name { id, ctx } = target.node else { unreachable!() };
|
|
||||||
|
|
||||||
self.defined_identifiers.insert(id);
|
let targets: Vec<_> = targets
|
||||||
let target_ty = if let Some(ty) = self.variable_mapping.get(&id) {
|
.into_iter()
|
||||||
*ty
|
.map(|target| -> Result<_, InferenceError> {
|
||||||
} else {
|
// In cases like `x = y = z = rhs`, `rhs`'s type will be constrained by
|
||||||
let unifier: &mut Unifier = self.unifier;
|
// the intersection of `x`, `y`, and `z` here.
|
||||||
self.function_data
|
let target = self.fold_assign_target(target, value.custom.unwrap())?;
|
||||||
.resolver
|
Ok(target)
|
||||||
.get_symbol_type(
|
|
||||||
unifier,
|
|
||||||
&self.top_level.definitions.read(),
|
|
||||||
self.primitives,
|
|
||||||
id,
|
|
||||||
)
|
|
||||||
.unwrap_or_else(|_| {
|
|
||||||
self.variable_mapping.insert(id, value_ty);
|
|
||||||
value_ty
|
|
||||||
})
|
})
|
||||||
};
|
.try_collect()?;
|
||||||
let location = target.location;
|
|
||||||
self.unifier.unify(value_ty, target_ty).map(|()| Located {
|
// Do this only after folding targets and value
|
||||||
location,
|
for target in &targets {
|
||||||
node: ExprKind::Name { id, ctx },
|
self.infer_pattern(target)?;
|
||||||
custom: Some(target_ty),
|
}
|
||||||
})
|
|
||||||
})
|
Located {
|
||||||
.collect();
|
|
||||||
let loc = node.location;
|
|
||||||
let targets = targets.map_err(|e| {
|
|
||||||
HashSet::from([e.at(Some(loc)).to_display(self.unifier).to_string()])
|
|
||||||
})?;
|
|
||||||
return Ok(Located {
|
|
||||||
location: node.location,
|
location: node.location,
|
||||||
node: ast::StmtKind::Assign {
|
node: ast::StmtKind::Assign {
|
||||||
targets,
|
targets,
|
||||||
|
type_comment,
|
||||||
|
config_comment,
|
||||||
value: Box::new(value),
|
value: Box::new(value),
|
||||||
type_comment: None,
|
|
||||||
config_comment: config_comment.clone(),
|
|
||||||
},
|
},
|
||||||
custom: None,
|
custom: None,
|
||||||
});
|
|
||||||
}
|
}
|
||||||
for target in targets {
|
|
||||||
self.infer_pattern(target)?;
|
|
||||||
}
|
|
||||||
fold::fold_stmt(self, node)?
|
|
||||||
}
|
}
|
||||||
ast::StmtKind::With { ref items, .. } => {
|
ast::StmtKind::With { ref items, .. } => {
|
||||||
for item in items {
|
for item in items {
|
||||||
|
@ -377,7 +369,8 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
_ => fold::fold_stmt(self, node)?,
|
_ => fold::fold_stmt(self, node)?,
|
||||||
};
|
};
|
||||||
match &stmt.node {
|
match &stmt.node {
|
||||||
ast::StmtKind::AnnAssign { .. }
|
ast::StmtKind::Assign { .. }
|
||||||
|
| ast::StmtKind::AnnAssign { .. }
|
||||||
| ast::StmtKind::Break { .. }
|
| ast::StmtKind::Break { .. }
|
||||||
| ast::StmtKind::Continue { .. }
|
| ast::StmtKind::Continue { .. }
|
||||||
| ast::StmtKind::Expr { .. }
|
| ast::StmtKind::Expr { .. }
|
||||||
|
@ -387,11 +380,6 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
|
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
|
||||||
self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?;
|
self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?;
|
||||||
}
|
}
|
||||||
ast::StmtKind::Assign { targets, value, .. } => {
|
|
||||||
for target in targets {
|
|
||||||
self.unify(target.custom.unwrap(), value.custom.unwrap(), &target.location)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ast::StmtKind::Raise { exc, cause, .. } => {
|
ast::StmtKind::Raise { exc, cause, .. } => {
|
||||||
if let Some(cause) = cause {
|
if let Some(cause) = cause {
|
||||||
return report_error("raise ... from cause is not supported", cause.location);
|
return report_error("raise ... from cause is not supported", cause.location);
|
||||||
|
@ -531,6 +519,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
_ => fold::fold_expr(self, node)?,
|
_ => fold::fold_expr(self, node)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let custom = match &expr.node {
|
let custom = match &expr.node {
|
||||||
ExprKind::Constant { value, .. } => Some(self.infer_constant(value, &expr.location)?),
|
ExprKind::Constant { value, .. } => Some(self.infer_constant(value, &expr.location)?),
|
||||||
ExprKind::Name { id, .. } => {
|
ExprKind::Name { id, .. } => {
|
||||||
|
@ -578,8 +567,6 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
Some(self.infer_identifier(*id)?)
|
Some(self.infer_identifier(*id)?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
|
||||||
ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
|
||||||
ExprKind::Attribute { value, attr, ctx } => {
|
ExprKind::Attribute { value, attr, ctx } => {
|
||||||
Some(self.infer_attribute(value, *attr, *ctx)?)
|
Some(self.infer_attribute(value, *attr, *ctx)?)
|
||||||
}
|
}
|
||||||
|
@ -593,8 +580,10 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
ExprKind::Compare { left, ops, comparators } => {
|
ExprKind::Compare { left, ops, comparators } => {
|
||||||
Some(self.infer_compare(expr.location, left, ops, comparators)?)
|
Some(self.infer_compare(expr.location, left, ops, comparators)?)
|
||||||
}
|
}
|
||||||
ExprKind::Subscript { value, slice, ctx, .. } => {
|
ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
||||||
Some(self.infer_subscript(value.as_ref(), slice.as_ref(), *ctx)?)
|
ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
||||||
|
ExprKind::Subscript { value, slice, .. } => {
|
||||||
|
Some(self.infer_getitem(value.as_ref(), slice.as_ref())?)
|
||||||
}
|
}
|
||||||
ExprKind::IfExp { test, body, orelse } => {
|
ExprKind::IfExp { test, body, orelse } => {
|
||||||
Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?)
|
Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?)
|
||||||
|
@ -612,22 +601,22 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type InferenceResult = Result<Type, HashSet<String>>;
|
type InferenceResult = Result<Type, InferenceError>;
|
||||||
|
|
||||||
impl<'a> Inferencer<'a> {
|
impl<'a> Inferencer<'a> {
|
||||||
/// Constrain a <: b
|
/// Constrain a <: b
|
||||||
/// Currently implemented as unification
|
/// Currently implemented as unification
|
||||||
fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), HashSet<String>> {
|
fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> {
|
||||||
self.unify(a, b, location)
|
self.unify(a, b, location)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), HashSet<String>> {
|
fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> {
|
||||||
self.unifier.unify(a, b).map_err(|e| {
|
self.unifier.unify(a, b).map_err(|e| {
|
||||||
HashSet::from([e.at(Some(*location)).to_display(self.unifier).to_string()])
|
HashSet::from([e.at(Some(*location)).to_display(self.unifier).to_string()])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), HashSet<String>> {
|
fn infer_pattern<T>(&mut self, pattern: &ast::Expr<T>) -> Result<(), InferenceError> {
|
||||||
match &pattern.node {
|
match &pattern.node {
|
||||||
ExprKind::Name { id, .. } => {
|
ExprKind::Name { id, .. } => {
|
||||||
if !self.defined_identifiers.contains(id) {
|
if !self.defined_identifiers.contains(id) {
|
||||||
|
@ -641,6 +630,13 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
ExprKind::List { elts, .. } => {
|
||||||
|
for elt in elts {
|
||||||
|
self.infer_pattern(elt)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
ExprKind::Starred { value, .. } => self.infer_pattern(value),
|
||||||
_ => Ok(()),
|
_ => Ok(()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -716,7 +712,7 @@ impl<'a> Inferencer<'a> {
|
||||||
location: Location,
|
location: Location,
|
||||||
args: Arguments,
|
args: Arguments,
|
||||||
body: ast::Expr<()>,
|
body: ast::Expr<()>,
|
||||||
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
|
) -> Result<ast::Expr<Option<Type>>, InferenceError> {
|
||||||
if !args.posonlyargs.is_empty()
|
if !args.posonlyargs.is_empty()
|
||||||
|| args.vararg.is_some()
|
|| args.vararg.is_some()
|
||||||
|| !args.kwonlyargs.is_empty()
|
|| !args.kwonlyargs.is_empty()
|
||||||
|
@ -787,7 +783,7 @@ impl<'a> Inferencer<'a> {
|
||||||
location: Location,
|
location: Location,
|
||||||
elt: ast::Expr<()>,
|
elt: ast::Expr<()>,
|
||||||
mut generators: Vec<Comprehension>,
|
mut generators: Vec<Comprehension>,
|
||||||
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
|
) -> Result<ast::Expr<Option<Type>>, InferenceError> {
|
||||||
if generators.len() != 1 {
|
if generators.len() != 1 {
|
||||||
return report_error(
|
return report_error(
|
||||||
"Only 1 generator statement for list comprehension is supported",
|
"Only 1 generator statement for list comprehension is supported",
|
||||||
|
@ -893,7 +889,7 @@ impl<'a> Inferencer<'a> {
|
||||||
id: StrRef,
|
id: StrRef,
|
||||||
arg_index: usize,
|
arg_index: usize,
|
||||||
shape_expr: Located<ExprKind>,
|
shape_expr: Located<ExprKind>,
|
||||||
) -> Result<(u64, ast::Expr<Option<Type>>), HashSet<String>> {
|
) -> Result<(u64, ast::Expr<Option<Type>>), InferenceError> {
|
||||||
/*
|
/*
|
||||||
### Further explanation
|
### Further explanation
|
||||||
|
|
||||||
|
@ -1030,7 +1026,7 @@ impl<'a> Inferencer<'a> {
|
||||||
func: &ast::Expr<()>,
|
func: &ast::Expr<()>,
|
||||||
args: &mut Vec<ast::Expr<()>>,
|
args: &mut Vec<ast::Expr<()>>,
|
||||||
keywords: &[Located<ast::KeywordData>],
|
keywords: &[Located<ast::KeywordData>],
|
||||||
) -> Result<Option<ast::Expr<Option<Type>>>, HashSet<String>> {
|
) -> Result<Option<ast::Expr<Option<Type>>>, InferenceError> {
|
||||||
let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else {
|
let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
@ -1588,7 +1584,7 @@ impl<'a> Inferencer<'a> {
|
||||||
func: ast::Expr<()>,
|
func: ast::Expr<()>,
|
||||||
mut args: Vec<ast::Expr<()>>,
|
mut args: Vec<ast::Expr<()>>,
|
||||||
keywords: Vec<Located<ast::KeywordData>>,
|
keywords: Vec<Located<ast::KeywordData>>,
|
||||||
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
|
) -> Result<ast::Expr<Option<Type>>, InferenceError> {
|
||||||
if let Some(spec_call_func) =
|
if let Some(spec_call_func) =
|
||||||
self.try_fold_special_call(location, &func, &mut args, &keywords)?
|
self.try_fold_special_call(location, &func, &mut args, &keywords)?
|
||||||
{
|
{
|
||||||
|
@ -1941,28 +1937,270 @@ impl<'a> Inferencer<'a> {
|
||||||
Ok(res.unwrap())
|
Ok(res.unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infers the type of a subscript expression on an `ndarray`.
|
/// Fold an assignment `"target_list"` recursively, and check RHS's type.
|
||||||
fn infer_subscript_ndarray(
|
/// See definition of `"target_list"` in <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
|
||||||
|
fn fold_assign_target_list(
|
||||||
&mut self,
|
&mut self,
|
||||||
value: &ast::Expr<Option<Type>>,
|
target_list_location: &Location,
|
||||||
slice: &ast::Expr<Option<Type>>,
|
mut targets: Vec<ast::Expr<()>>,
|
||||||
dummy_tvar: Type,
|
rhs_ty: Type,
|
||||||
ndims: Type,
|
) -> Result<Vec<ast::Expr<Option<Type>>>, InferenceError> {
|
||||||
) -> InferenceResult {
|
// TODO: Allow bidirectional typechecking? Currently RHS's type has to be resolved.
|
||||||
debug_assert!(matches!(
|
let TypeEnum::TTuple { ty: rhs_tys } = &*self.unifier.get_ty(rhs_ty) else {
|
||||||
&*self.unifier.get_ty_immutable(dummy_tvar),
|
// TODO: Allow RHS AST-aware error reporting
|
||||||
TypeEnum::TVar { is_const_generic: false, .. }
|
return report_error(
|
||||||
));
|
"LHS target list pattern requires RHS to be a tuple type",
|
||||||
|
*target_list_location,
|
||||||
let constrained_ty =
|
);
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims));
|
|
||||||
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
|
|
||||||
|
|
||||||
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
|
|
||||||
panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims))
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let ndims = values
|
// Find the starred target if it exists.
|
||||||
|
let mut starred_target_index: Option<usize> = None; // Index of the "starred" target. If it exists, there may only be one.
|
||||||
|
for (i, target) in targets.iter().enumerate() {
|
||||||
|
if matches!(target.node, ExprKind::Starred { .. }) {
|
||||||
|
if starred_target_index.is_none() {
|
||||||
|
// First "starred" target found.
|
||||||
|
starred_target_index = Some(i);
|
||||||
|
} else {
|
||||||
|
// Second "starred" targets found. This is an error.
|
||||||
|
return report_error(
|
||||||
|
"there can only be one starred target, but found another one",
|
||||||
|
target.location,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut folded_targets: Vec<ast::Expr<Option<Type>>> = Vec::new();
|
||||||
|
if let Some(starred_target_index) = starred_target_index {
|
||||||
|
if rhs_tys.len() < targets.len() - 1 {
|
||||||
|
/*
|
||||||
|
Rules:
|
||||||
|
```
|
||||||
|
(x, *ys, z) = (1,) # error
|
||||||
|
(x, *ys, z) = (1, 2) # ok, ys = ()
|
||||||
|
(x, *ys, z) = (1, 2, 3) # ok, ys = (2,)
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
return report_error(
|
||||||
|
&format!(
|
||||||
|
"Target list pattern requires RHS tuple type have to at least {} element(s), but RHS only has {} element(s)",
|
||||||
|
targets.len() - 1,
|
||||||
|
rhs_tys.len()
|
||||||
|
),
|
||||||
|
*target_list_location
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
(a, b, c, ..., *xs, ..., x, y, z)
|
||||||
|
before ^^^^^^^^^^^^ ^^^ ^^^^^^^^^^^^ after
|
||||||
|
starred
|
||||||
|
*/
|
||||||
|
|
||||||
|
let targets_after = targets.drain(starred_target_index + 1..).collect_vec();
|
||||||
|
let target_starred = targets.pop().unwrap();
|
||||||
|
let targets_before = targets;
|
||||||
|
|
||||||
|
let a = targets_before.len();
|
||||||
|
let b = rhs_tys.len() - targets_after.len();
|
||||||
|
|
||||||
|
let rhs_tys_before = &rhs_tys[..a];
|
||||||
|
let rhs_tys_starred = &rhs_tys[a..b];
|
||||||
|
let rhs_tys_after = &rhs_tys[b..];
|
||||||
|
|
||||||
|
// Fold before the starred target
|
||||||
|
for (target, rhs_ty) in izip!(targets_before, rhs_tys_before) {
|
||||||
|
folded_targets.push(self.fold_assign_target(target, *rhs_ty)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fold the starred target
|
||||||
|
if let ExprKind::Starred { value: target, .. } = target_starred.node {
|
||||||
|
let ty = self.unifier.add_ty(TypeEnum::TTuple { ty: rhs_tys_starred.to_vec() });
|
||||||
|
let folded_target = self.fold_assign_target(*target, ty)?;
|
||||||
|
folded_targets.push(Located {
|
||||||
|
location: target_starred.location,
|
||||||
|
node: ExprKind::Starred {
|
||||||
|
value: Box::new(folded_target),
|
||||||
|
ctx: ExprContext::Store,
|
||||||
|
},
|
||||||
|
custom: None,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fold after the starred target
|
||||||
|
for (target, rhs_ty) in izip!(targets_after, rhs_tys_after) {
|
||||||
|
folded_targets.push(self.fold_assign_target(target, *rhs_ty)?);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Fold target list without a "starred" target.
|
||||||
|
if rhs_tys.len() != targets.len() {
|
||||||
|
return report_error(
|
||||||
|
&format!(
|
||||||
|
"Target list pattern requires RHS tuple type have to {} element(s), but RHS only has {} element(s)",
|
||||||
|
targets.len() - 1,
|
||||||
|
rhs_tys.len()
|
||||||
|
),
|
||||||
|
*target_list_location
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (target, rhs_ty) in izip!(targets, rhs_tys) {
|
||||||
|
folded_targets.push(self.fold_assign_target(target, *rhs_ty)?);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(folded_targets)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fold an assignment "target" recursively, and check RHS's type.
|
||||||
|
/// See definition of "target" in <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
|
||||||
|
fn fold_assign_target(
|
||||||
|
&mut self,
|
||||||
|
target: ast::Expr<()>,
|
||||||
|
rhs_ty: Type,
|
||||||
|
) -> Result<ast::Expr<Option<Type>>, InferenceError> {
|
||||||
|
match target.node {
|
||||||
|
ExprKind::Name { id, .. } => {
|
||||||
|
// Fold on "identifier"
|
||||||
|
match self.variable_mapping.get(&id) {
|
||||||
|
None => {
|
||||||
|
// Assigning to a new variable name; RHS's type could be anything.
|
||||||
|
let expected_rhs_ty = self
|
||||||
|
.unifier
|
||||||
|
.get_fresh_var(
|
||||||
|
Some(format!("type_of_{id}").into()),
|
||||||
|
Some(target.location),
|
||||||
|
)
|
||||||
|
.ty;
|
||||||
|
self.variable_mapping.insert(id, expected_rhs_ty); // Register new variable
|
||||||
|
self.constrain(rhs_ty, expected_rhs_ty, &target.location)?;
|
||||||
|
}
|
||||||
|
Some(expected_rhs_ty) => {
|
||||||
|
// Re-assigning to an existing variable name.
|
||||||
|
self.constrain(rhs_ty, *expected_rhs_ty, &target.location)?;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(Located {
|
||||||
|
location: target.location,
|
||||||
|
node: ExprKind::Name { id, ctx: ExprContext::Store },
|
||||||
|
custom: Some(rhs_ty), // Type info is needed here because of the CodeGenerator.
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ExprKind::Attribute { .. } => {
|
||||||
|
// Fold on "attributeref"
|
||||||
|
let pattern = self.fold_expr(target)?;
|
||||||
|
let expected_rhs_ty = pattern.custom.unwrap();
|
||||||
|
self.constrain(rhs_ty, expected_rhs_ty, &pattern.location)?;
|
||||||
|
Ok(pattern)
|
||||||
|
}
|
||||||
|
ExprKind::Subscript { value: target, slice: key, .. } => {
|
||||||
|
// Fold on "slicing" or "subscription"
|
||||||
|
// TODO: Make `__setitem__` a general object field like `__add__` in NAC3?
|
||||||
|
let target = self.fold_expr(*target)?;
|
||||||
|
let key = self.fold_expr(*key)?;
|
||||||
|
|
||||||
|
let expected_rhs_ty = self.infer_setitem_value_type(&target, &key)?;
|
||||||
|
self.constrain(rhs_ty, expected_rhs_ty, &target.location)?;
|
||||||
|
|
||||||
|
Ok(Located {
|
||||||
|
location: target.location,
|
||||||
|
node: ExprKind::Subscript {
|
||||||
|
value: Box::new(target),
|
||||||
|
slice: Box::new(key),
|
||||||
|
ctx: ExprContext::Store,
|
||||||
|
},
|
||||||
|
custom: None, // We don't need to know the type of `target[key]`
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ExprKind::List { elts, .. } => {
|
||||||
|
// Fold on `"[" [target_list] "]"`
|
||||||
|
let elts = self.fold_assign_target_list(&target.location, elts, rhs_ty)?;
|
||||||
|
Ok(Located {
|
||||||
|
location: target.location,
|
||||||
|
node: ExprKind::List { ctx: ExprContext::Store, elts },
|
||||||
|
custom: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ExprKind::Tuple { elts, .. } => {
|
||||||
|
// Fold on `"(" [target_list] ")"`
|
||||||
|
let elts = self.fold_assign_target_list(&target.location, elts, rhs_ty)?;
|
||||||
|
Ok(Located {
|
||||||
|
location: target.location,
|
||||||
|
node: ExprKind::Tuple { ctx: ExprContext::Store, elts },
|
||||||
|
custom: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ExprKind::Starred { .. } => report_error(
|
||||||
|
"starred assignment target must be in a list or tuple",
|
||||||
|
target.location,
|
||||||
|
),
|
||||||
|
_ => report_error("encountered unsupported/illegal LHS pattern", target.location),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Typecheck the subscript slice indexing into an ndarray.
|
||||||
|
///
|
||||||
|
/// That is:
|
||||||
|
/// ```python
|
||||||
|
/// my_ndarray[::-2, 1, :, None, 9:23]
|
||||||
|
/// ^^^^^^^^^^^^^^^^^^^^^^ this
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// The number of dimensions to subtract from the ndarray being indexed is also calculated and returned,
|
||||||
|
/// it could even be negative when more axes are added because of `None`.
|
||||||
|
fn fold_ndarray_subscript_slice(
|
||||||
|
&mut self,
|
||||||
|
slice: &ast::Expr<Option<Type>>,
|
||||||
|
) -> Result<i128, InferenceError> {
|
||||||
|
// TODO: Handle `None` / `np.newaxis`
|
||||||
|
|
||||||
|
// Flatten `slice` into subscript indices.
|
||||||
|
let indices = match &slice.node {
|
||||||
|
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
|
||||||
|
_ => vec![slice],
|
||||||
|
};
|
||||||
|
|
||||||
|
// Typecheck the subscript indices.
|
||||||
|
// We will also take the opportunity to deduce `dims_to_subtract` as well
|
||||||
|
let mut dims_to_subtract: i128 = 0;
|
||||||
|
for index in indices {
|
||||||
|
if let ExprKind::Slice { lower, upper, step } = &index.node {
|
||||||
|
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
||||||
|
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Treat anything else as an integer index, and force unify their type to int32.
|
||||||
|
self.unify(index.custom.unwrap(), self.primitives.int32, &index.location)?;
|
||||||
|
dims_to_subtract += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(dims_to_subtract)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the `ndims` [`Type`] of an ndarray is valid (e.g., no negative values),
|
||||||
|
/// and attempt to subtract `ndims` by `dims_to_subtract` and return subtracted `ndims`.
|
||||||
|
///
|
||||||
|
/// `dims_to_subtract` can be set to `0` if you only want to check if `ndims` is valid.
|
||||||
|
fn check_ndarray_ndims_and_subtract(
|
||||||
|
&mut self,
|
||||||
|
target_ty: Type,
|
||||||
|
ndims: Type,
|
||||||
|
dims_to_subtract: i128,
|
||||||
|
) -> Result<Type, InferenceError> {
|
||||||
|
// Typecheck `ndims`.
|
||||||
|
let TypeEnum::TLiteral { values: ndims, .. } = &*self.unifier.get_ty_immutable(ndims)
|
||||||
|
else {
|
||||||
|
panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims))
|
||||||
|
};
|
||||||
|
assert!(!ndims.is_empty());
|
||||||
|
|
||||||
|
// Check if there are negative literals.
|
||||||
|
// NOTE: Don't mix this with subtracting dims, otherwise the user errors could be confusing.
|
||||||
|
let ndims = ndims
|
||||||
.iter()
|
.iter()
|
||||||
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
|
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
@ -1973,204 +2211,229 @@ impl<'a> Inferencer<'a> {
|
||||||
)])
|
)])
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
assert!(!ndims.is_empty());
|
// Infer the new `ndims` after indexing the ndarray with `slice`.
|
||||||
|
// Disallow subscripting if any Literal value will subscript on an element.
|
||||||
// The number of dimensions subscripted by the index expression.
|
|
||||||
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
|
|
||||||
// dimension will remove a dimension.
|
|
||||||
let subscripted_dims = match &slice.node {
|
|
||||||
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
|
|
||||||
if let ExprKind::Slice { .. } = &value_subexpr.node {
|
|
||||||
acc
|
|
||||||
} else {
|
|
||||||
acc + 1
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
|
|
||||||
ExprKind::Slice { .. } => 0,
|
|
||||||
_ => 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
|
|
||||||
// ndarray[T, Literal[1]] - Non-Slice index always returns an object of type T
|
|
||||||
|
|
||||||
assert_ne!(ndims[0], 0);
|
|
||||||
|
|
||||||
Ok(dummy_tvar)
|
|
||||||
} else {
|
|
||||||
// Otherwise - Index returns an object of type ndarray[T, Literal[N - subscripted_dims]]
|
|
||||||
|
|
||||||
// Disallow subscripting if any Literal value will subscript on an element
|
|
||||||
let new_ndims = ndims
|
let new_ndims = ndims
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|v| {
|
.map(|v| {
|
||||||
let v = i128::from(v) - i128::from(subscripted_dims);
|
let v = i128::from(v) - dims_to_subtract;
|
||||||
u64::try_from(v)
|
u64::try_from(v)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
.map_err(|_| {
|
.map_err(|_| {
|
||||||
HashSet::from([format!(
|
HashSet::from([format!(
|
||||||
"Cannot subscript {} by {subscripted_dims} dimensions",
|
"Cannot subscript {} by {dims_to_subtract} dimension(s)",
|
||||||
self.unifier.stringify(value.custom.unwrap()),
|
self.unifier.stringify(target_ty),
|
||||||
)])
|
)])
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if new_ndims.iter().any(|v| *v == 0) {
|
let new_ndims_ty = self
|
||||||
|
.unifier
|
||||||
|
.get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
|
||||||
|
|
||||||
|
Ok(new_ndims_ty)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Infer the type of the result of indexing into an ndarray.
|
||||||
|
///
|
||||||
|
/// * `ndarray_ty` - The [`Type`] of the ndarray being indexed into.
|
||||||
|
/// * `slice` - The subscript expression indexing into the ndarray.
|
||||||
|
fn infer_ndarray_subscript(
|
||||||
|
&mut self,
|
||||||
|
ndarray_ty: Type,
|
||||||
|
slice: &ast::Expr<Option<Type>>,
|
||||||
|
) -> InferenceResult {
|
||||||
|
let (dtype, ndims) = unpack_ndarray_var_tys(self.unifier, ndarray_ty);
|
||||||
|
|
||||||
|
let dims_to_substract = self.fold_ndarray_subscript_slice(slice)?;
|
||||||
|
let new_ndims =
|
||||||
|
self.check_ndarray_ndims_and_subtract(ndarray_ty, ndims, dims_to_substract)?;
|
||||||
|
|
||||||
|
// Now we need extra work to check `new_ndims` to see if the user has indexed into a single element.
|
||||||
|
|
||||||
|
let TypeEnum::TLiteral { values: new_ndims_values, .. } = &*self.unifier.get_ty(new_ndims)
|
||||||
|
else {
|
||||||
|
unreachable!("infer_ndarray_ndims should always return TLiteral")
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_ndims_values = new_ndims_values
|
||||||
|
.iter()
|
||||||
|
.map(|v| u64::try_from(v.clone()).expect("new_ndims should be convertible to u64"))
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
if new_ndims_values.len() == 1 && new_ndims_values[0] == 0 {
|
||||||
|
// The subscripted ndarray must be unsized
|
||||||
|
// The user must be indexing into a single element
|
||||||
|
Ok(dtype)
|
||||||
|
} else {
|
||||||
|
// The subscripted ndarray is not unsized / may not be unsized. (i.e., may or may not have indexed into a single element)
|
||||||
|
|
||||||
|
if new_ndims_values.iter().any(|v| *v == 0) {
|
||||||
|
// TODO: Difficult to implement since now the return may both be a scalar type, or an ndarray type.
|
||||||
unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented")
|
unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndims_ty = self
|
let new_ndarray_ty =
|
||||||
.unifier
|
make_ndarray_ty(self.unifier, self.primitives, Some(dtype), Some(new_ndims));
|
||||||
.get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
|
Ok(new_ndarray_ty)
|
||||||
let subscripted_ty =
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty));
|
|
||||||
|
|
||||||
Ok(subscripted_ty)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_subscript(
|
/// Infer the type of the result of indexing into a list.
|
||||||
|
///
|
||||||
|
/// * `list_ty` - The [`Type`] of the list being indexed into.
|
||||||
|
/// * `key` - The subscript expression indexing into the list.
|
||||||
|
fn infer_list_subscript(
|
||||||
&mut self,
|
&mut self,
|
||||||
value: &ast::Expr<Option<Type>>,
|
list_ty: Type,
|
||||||
slice: &ast::Expr<Option<Type>>,
|
key: &ast::Expr<Option<Type>>,
|
||||||
ctx: ExprContext,
|
) -> Result<Type, InferenceError> {
|
||||||
) -> InferenceResult {
|
let TypeEnum::TObj { params: list_params, .. } = &*self.unifier.get_ty(list_ty) else {
|
||||||
let report_unscriptable_error = |unifier: &mut Unifier| {
|
unreachable!()
|
||||||
// User is attempting to index into a value of an unsupported type.
|
|
||||||
|
|
||||||
let value_ty = value.custom.unwrap();
|
|
||||||
let value_ty_str = unifier.stringify(value_ty);
|
|
||||||
|
|
||||||
return report_error(
|
|
||||||
format!("'{value_ty_str}' object is not subscriptable").as_str(),
|
|
||||||
slice.location, // using the slice's location (rather than value's) because it is more clear
|
|
||||||
);
|
|
||||||
};
|
};
|
||||||
|
let item_ty = iter_type_vars(list_params).nth(0).unwrap().ty;
|
||||||
|
|
||||||
let ty = self.unifier.get_dummy_var().ty;
|
if let ExprKind::Slice { lower, upper, step } = &key.node {
|
||||||
match &slice.node {
|
// Typecheck on the slice
|
||||||
ExprKind::Slice { lower, upper, step } => {
|
|
||||||
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
||||||
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
|
let v_ty = v.custom.unwrap();
|
||||||
}
|
self.constrain(v_ty, self.primitives.int32, &v.location)?;
|
||||||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
|
||||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
|
||||||
let list_tvar = iter_type_vars(params).nth(0).unwrap();
|
|
||||||
self.unifier
|
|
||||||
.subst(
|
|
||||||
self.primitives.list,
|
|
||||||
&into_var_map([TypeVar { id: list_tvar.id, ty }]),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
||||||
let (_, ndims) =
|
|
||||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
|
||||||
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => {
|
|
||||||
return report_unscriptable_error(self.unifier);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
|
|
||||||
Ok(list_like_ty)
|
|
||||||
}
|
|
||||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
|
||||||
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
||||||
let (_, ndims) =
|
|
||||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
|
||||||
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
// the index is a constant, so value can be a sequence.
|
|
||||||
let ind: Option<i32> = (*val).try_into().ok();
|
|
||||||
let ind =
|
|
||||||
ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
|
|
||||||
let map = once((
|
|
||||||
ind.into(),
|
|
||||||
RecordField::new(ty, ctx == ExprContext::Store, Some(value.location)),
|
|
||||||
))
|
|
||||||
.collect();
|
|
||||||
let seq = self.unifier.add_record(map);
|
|
||||||
self.constrain(value.custom.unwrap(), seq, &value.location)?;
|
|
||||||
Ok(ty)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ExprKind::Tuple { elts, .. } => {
|
|
||||||
if value
|
|
||||||
.custom
|
|
||||||
.unwrap()
|
|
||||||
.obj_id(self.unifier)
|
|
||||||
.is_some_and(|id| id == PrimDef::NDArray.id())
|
|
||||||
.not()
|
|
||||||
{
|
|
||||||
return report_error(
|
|
||||||
"Tuple slices are only supported for ndarrays",
|
|
||||||
slice.location,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
for elt in elts {
|
|
||||||
if let ExprKind::Slice { lower, upper, step } = &elt.node {
|
|
||||||
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
|
||||||
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
|
|
||||||
}
|
}
|
||||||
|
Ok(list_ty) // type list[T]
|
||||||
} else {
|
} else {
|
||||||
self.constrain(elt.custom.unwrap(), self.primitives.int32, &elt.location)?;
|
// Treat anything else as an integer index, and force unify their type to int32.
|
||||||
|
self.constrain(key.custom.unwrap(), self.primitives.int32, &key.location)?;
|
||||||
|
Ok(item_ty) // type T
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
/// Generate a type that constrains the type of `target` to have a `__getitem__` at `index`.
|
||||||
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
///
|
||||||
|
/// * `target` - The target being indexed by `index`.
|
||||||
|
/// * `index` - The constant index.
|
||||||
|
/// * `mutable` - Should the constraint be mutable or immutable?
|
||||||
|
fn get_constant_index_item_type(
|
||||||
|
&mut self,
|
||||||
|
target: &ast::Expr<Option<Type>>,
|
||||||
|
index: i128,
|
||||||
|
mutable: bool,
|
||||||
|
) -> InferenceResult {
|
||||||
|
let Ok(index) = i32::try_from(index) else {
|
||||||
|
return Err(HashSet::from(["Index must be int32".to_string()]));
|
||||||
|
};
|
||||||
|
|
||||||
|
let item_ty = self.unifier.get_dummy_var().ty; // To be resolved by the unifier
|
||||||
|
|
||||||
|
// Constrain `target`
|
||||||
|
let fields_constrain = Mapping::from_iter([(
|
||||||
|
RecordKey::Int(index),
|
||||||
|
RecordField::new(item_ty, mutable, Some(target.location)),
|
||||||
|
)]);
|
||||||
|
let fields_constrain_ty = self.unifier.add_record(fields_constrain);
|
||||||
|
self.constrain(target.custom.unwrap(), fields_constrain_ty, &target.location)?;
|
||||||
|
|
||||||
|
Ok(item_ty)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Infer the return type of a `__getitem__` expression.
|
||||||
|
///
|
||||||
|
/// i.e., `target[key]`, where the [`ExprContext`] is [`ExprContext::Load`].
|
||||||
|
fn infer_getitem(
|
||||||
|
&mut self,
|
||||||
|
target: &ast::Expr<Option<Type>>,
|
||||||
|
key: &ast::Expr<Option<Type>>,
|
||||||
|
) -> InferenceResult {
|
||||||
|
let target_ty = target.custom.unwrap();
|
||||||
|
|
||||||
|
match &*self.unifier.get_ty(target_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == self.primitives.list.obj_id(self.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
self.infer_list_subscript(target_ty, key)
|
||||||
|
}
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == self.primitives.ndarray.obj_id(self.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
self.infer_ndarray_subscript(target_ty, key)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
// Now `target_ty` either:
|
||||||
return report_error(
|
// 1) is a `TTuple`, or
|
||||||
"Tuple index must be a constant (KernelInvariant is also not supported)",
|
// 2) is simply not obvious for doing __getitem__ on.
|
||||||
slice.location,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// the index is not a constant, so value can only be a list-like structure
|
if let ExprKind::Constant { value: ast::Constant::Int(index), .. } = &key.node {
|
||||||
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
// If `key` is a constant int, then the value can be a sequence.
|
||||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
// Therefore, this can be handled by the unifier
|
||||||
self.constrain(
|
let getitem_ty = self.get_constant_index_item_type(target, *index, false)?;
|
||||||
slice.custom.unwrap(),
|
Ok(getitem_ty)
|
||||||
self.primitives.int32,
|
} else {
|
||||||
&slice.location,
|
// Out of ways to resolve __getitem__, throw an error.
|
||||||
)?;
|
report_error(
|
||||||
let list_tvar = iter_type_vars(params).nth(0).unwrap();
|
&format!(
|
||||||
let list = self
|
"'{}' cannot be indexed by this subscript",
|
||||||
.unifier
|
self.unifier.stringify(target_ty)
|
||||||
.subst(
|
),
|
||||||
self.primitives.list,
|
key.location,
|
||||||
&into_var_map([TypeVar { id: list_tvar.id, ty }]),
|
|
||||||
)
|
)
|
||||||
.unwrap();
|
|
||||||
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
|
||||||
Ok(ty)
|
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
}
|
||||||
let (_, ndims) =
|
}
|
||||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
}
|
||||||
|
|
||||||
let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
|
/// Fold an item assignment, and return a type that constrains the type of RHS.
|
||||||
.into_iter()
|
fn infer_setitem_value_type(
|
||||||
.unique()
|
&mut self,
|
||||||
.collect_vec();
|
target: &ast::Expr<Option<Type>>,
|
||||||
let valid_index_ty = self
|
key: &ast::Expr<Option<Type>>,
|
||||||
.unifier
|
) -> Result<Type, InferenceError> {
|
||||||
.get_fresh_var_with_range(valid_index_tys.as_slice(), None, None)
|
let target_ty = target.custom.unwrap();
|
||||||
.ty;
|
match &*self.unifier.get_ty(target_ty) {
|
||||||
self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?;
|
TypeEnum::TObj { obj_id, .. }
|
||||||
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
if *obj_id == self.primitives.list.obj_id(self.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// Handle list item assignment
|
||||||
|
|
||||||
|
// The expected value type is the same as the type of list.__getitem__
|
||||||
|
self.infer_list_subscript(target_ty, key)
|
||||||
}
|
}
|
||||||
_ => report_unscriptable_error(self.unifier),
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == self.primitives.ndarray.obj_id(self.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// Handle ndarray item assignment
|
||||||
|
|
||||||
|
// NOTE: `value` can either be an ndarray of or a scalar, even if `target` is an unsized ndarray.
|
||||||
|
|
||||||
|
// TODO: NumPy does automatic casting on `value`. (Currently not supported)
|
||||||
|
// See https://numpy.org/doc/stable/user/basics.indexing.html#assigning-values-to-indexed-arrays
|
||||||
|
|
||||||
|
let (scalar_ty, _) = unpack_ndarray_var_tys(self.unifier, target_ty);
|
||||||
|
let ndarray_ty =
|
||||||
|
make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None);
|
||||||
|
|
||||||
|
let expected_value_ty =
|
||||||
|
self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray_ty], None, None).ty;
|
||||||
|
Ok(expected_value_ty)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Handle item assignments of other types.
|
||||||
|
|
||||||
|
// Now `target_ty` either:
|
||||||
|
// 1) is a `TTuple`, or
|
||||||
|
// 2) is simply not obvious for doing __setitem__ on.
|
||||||
|
|
||||||
|
if let ExprKind::Constant { value: ast::Constant::Int(index), .. } = &key.node {
|
||||||
|
// If `key` is a constant int, then the value can be a sequence.
|
||||||
|
// Therefore, this can be handled by the unifier
|
||||||
|
self.get_constant_index_item_type(target, *index, false)
|
||||||
|
} else {
|
||||||
|
// Out of ways to resolve __getitem__, throw an error.
|
||||||
|
report_error(
|
||||||
|
&format!(
|
||||||
|
"'{}' does not allow item assignment with this subscript",
|
||||||
|
self.unifier.stringify(target_ty)
|
||||||
|
),
|
||||||
|
key.location,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
@extern
|
||||||
|
def output_int32(x: int32):
|
||||||
|
...
|
||||||
|
|
||||||
|
@extern
|
||||||
|
def output_bool(x: bool):
|
||||||
|
...
|
||||||
|
|
||||||
|
def example1():
|
||||||
|
x, *ys, z = (1, 2, 3, 4, 5)
|
||||||
|
output_int32(x)
|
||||||
|
output_int32(ys[0])
|
||||||
|
output_int32(ys[1])
|
||||||
|
output_int32(ys[2])
|
||||||
|
output_int32(z)
|
||||||
|
|
||||||
|
def example2():
|
||||||
|
x, y, *zs = (1, 2, 3, 4, 5)
|
||||||
|
output_int32(x)
|
||||||
|
output_int32(y)
|
||||||
|
output_int32(zs[0])
|
||||||
|
output_int32(zs[1])
|
||||||
|
output_int32(zs[2])
|
||||||
|
|
||||||
|
def example3():
|
||||||
|
*xs, y, z = (1, 2, 3, 4, 5)
|
||||||
|
output_int32(xs[0])
|
||||||
|
output_int32(xs[1])
|
||||||
|
output_int32(xs[2])
|
||||||
|
output_int32(y)
|
||||||
|
output_int32(z)
|
||||||
|
|
||||||
|
def example4():
|
||||||
|
# Example from: https://docs.python.org/3/reference/simple_stmts.html#assignment-statements
|
||||||
|
x = [0, 1]
|
||||||
|
i = 0
|
||||||
|
i, x[i] = 1, 2 # i is updated, then x[i] is updated
|
||||||
|
output_int32(i)
|
||||||
|
output_int32(x[0])
|
||||||
|
output_int32(x[1])
|
||||||
|
|
||||||
|
class A:
|
||||||
|
value: int32
|
||||||
|
def __init__(self):
|
||||||
|
self.value = 1000
|
||||||
|
|
||||||
|
def example5():
|
||||||
|
ws = [88, 7, 8]
|
||||||
|
a = A()
|
||||||
|
x, [y, *ys, a.value], ws[0], (ws[0],) = 1, (2, False, 4, 5), 99, (6,)
|
||||||
|
output_int32(x)
|
||||||
|
output_int32(y)
|
||||||
|
output_bool(ys[0])
|
||||||
|
output_int32(ys[1])
|
||||||
|
output_int32(a.value)
|
||||||
|
output_int32(ws[0])
|
||||||
|
output_int32(ws[1])
|
||||||
|
output_int32(ws[2])
|
||||||
|
|
||||||
|
def run() -> int32:
|
||||||
|
example1()
|
||||||
|
example2()
|
||||||
|
example3()
|
||||||
|
example4()
|
||||||
|
example5()
|
||||||
|
return 0
|
Loading…
Reference in New Issue