From fb9fe8edf2fb1bbcc95b13a05b2b46d7ffb9000b Mon Sep 17 00:00:00 2001 From: lyken Date: Fri, 2 Aug 2024 15:01:38 +0800 Subject: [PATCH] core: reimplement assignment type inference and codegen - distinguish between setitem and getitem - allow starred assignment targets, but the assigned value would be a tuple - allow both [...] and (...) to be target lists --- nac3core/src/codegen/expr.rs | 2 +- nac3core/src/codegen/generator.rs | 36 +- nac3core/src/codegen/stmt.rs | 382 ++++++--- nac3core/src/typecheck/function_check.rs | 7 +- nac3core/src/typecheck/type_inferencer/mod.rs | 797 ++++++++++++------ nac3standalone/demo/src/assignment.py | 66 ++ 6 files changed, 884 insertions(+), 406 deletions(-) create mode 100644 nac3standalone/demo/src/assignment.py diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index f727ad8a0..66232d89d 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1122,7 +1122,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( ) .into_pointer_value(); let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); - generator.gen_assign(ctx, target, val.into())?; + generator.gen_assign(ctx, target, val.into(), elt.custom.unwrap())?; } _ => { panic!( diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index bb822f19d..6406ba84e 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -123,11 +123,45 @@ pub trait CodeGenerator { ctx: &mut CodeGenContext<'ctx, '_>, target: &Expr>, value: ValueEnum<'ctx>, + value_ty: Type, ) -> Result<(), String> where Self: Sized, { - gen_assign(self, ctx, target, value) + gen_assign(self, ctx, target, value, value_ty) + } + + /// Generate code for an assignment expression where LHS is a `"target_list"`. + /// + /// See . + fn gen_assign_target_list<'ctx>( + &mut self, + ctx: &mut CodeGenContext<'ctx, '_>, + targets: &Vec>>, + value: ValueEnum<'ctx>, + value_ty: Type, + ) -> Result<(), String> + where + Self: Sized, + { + gen_assign_target_list(self, ctx, targets, value, value_ty) + } + + /// Generate code for an item assignment. + /// + /// i.e., `target[key] = value` + fn gen_setitem<'ctx>( + &mut self, + ctx: &mut CodeGenContext<'ctx, '_>, + target: &Expr>, + key: &Expr>, + value: ValueEnum<'ctx>, + value_ty: Type, + ) -> Result<(), String> + where + Self: Sized, + { + gen_setitem(self, ctx, target, key, value, value_ty) } /// Generate code for a while expression. diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 1130cc183..93cee3bd0 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -10,10 +10,10 @@ use crate::{ expr::gen_binop_expr, gen_in_range_check, }, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, + toplevel::{DefinitionId, TopLevelDef}, typecheck::{ magic_methods::Binop, - typedef::{FunSignature, Type, TypeEnum}, + typedef::{iter_type_vars, FunSignature, Type, TypeEnum}, }, }; use inkwell::{ @@ -23,10 +23,10 @@ use inkwell::{ values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, IntPredicate, }; +use itertools::{izip, Itertools}; use nac3parser::ast::{ Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef, }; -use std::convert::TryFrom; /// See [`CodeGenerator::gen_var_alloc`]. 
pub fn gen_var<'ctx>( @@ -97,8 +97,6 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( pattern: &Expr>, name: Option<&str>, ) -> Result>, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - // very similar to gen_expr, but we don't do an extra load at the end // and we flatten nested tuples Ok(Some(match &pattern.node { @@ -137,65 +135,6 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( } .unwrap() } - ExprKind::Subscript { value, slice, .. } => { - match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() { - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { - let v = generator - .gen_expr(ctx, value)? - .unwrap() - .to_basic_value_enum(ctx, generator, value.custom.unwrap())? - .into_pointer_value(); - let v = ListValue::from_ptr_val(v, llvm_usize, None); - let len = v.load_size(ctx, Some("len")); - let raw_index = generator - .gen_expr(ctx, slice)? - .unwrap() - .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? - .into_int_value(); - let raw_index = ctx - .builder - .build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext") - .unwrap(); - // handle negative index - let is_negative = ctx - .builder - .build_int_compare( - IntPredicate::SLT, - raw_index, - generator.get_size_type(ctx.ctx).const_zero(), - "is_neg", - ) - .unwrap(); - let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted").unwrap(); - let index = ctx - .builder - .build_select(is_negative, adjusted, raw_index, "index") - .map(BasicValueEnum::into_int_value) - .unwrap(); - // unsigned less than is enough, because negative index after adjustment is - // bigger than the length (for unsigned cmp) - let bound_check = ctx - .builder - .build_int_compare(IntPredicate::ULT, index, len, "inbound") - .unwrap(); - ctx.make_assert( - generator, - bound_check, - "0:IndexError", - "index {0} out of bounds 0:{1}", - [Some(raw_index), Some(len), None], - slice.location, - ); - v.data().ptr_offset(ctx, generator, &index, name) - } - - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - todo!() - } - - _ => unreachable!(), - } - } _ => unreachable!(), })) } @@ -206,70 +145,20 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( ctx: &mut CodeGenContext<'ctx, '_>, target: &Expr>, value: ValueEnum<'ctx>, + value_ty: Type, ) -> Result<(), String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - + // See https://docs.python.org/3/reference/simple_stmts.html#assignment-statements. match &target.node { - ExprKind::Tuple { elts, .. } => { - let BasicValueEnum::StructValue(v) = - value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? - else { - unreachable!() - }; - - for (i, elt) in elts.iter().enumerate() { - let v = ctx - .builder - .build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem") - .unwrap(); - generator.gen_assign(ctx, elt, v.into())?; - } + ExprKind::Subscript { value: target, slice: key, .. } => { + // Handle "slicing" or "subscription" + generator.gen_setitem(ctx, target, key, value, value_ty)?; } - ExprKind::Subscript { value: ls, slice, .. } - if matches!(&slice.node, ExprKind::Slice { .. }) => - { - let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() }; - - let ls = generator - .gen_expr(ctx, ls)? - .unwrap() - .to_basic_value_enum(ctx, generator, ls.custom.unwrap())? - .into_pointer_value(); - let ls = ListValue::from_ptr_val(ls, llvm_usize, None); - let Some((start, end, step)) = - handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))? 
- else { - return Ok(()); - }; - let value = value - .to_basic_value_enum(ctx, generator, target.custom.unwrap())? - .into_pointer_value(); - let value = ListValue::from_ptr_val(value, llvm_usize, None); - let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { - *params.iter().next().unwrap().1 - } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 - } - _ => unreachable!(), - }; - - let ty = ctx.get_llvm_type(generator, ty); - let Some(src_ind) = handle_slice_indices( - &None, - &None, - &None, - ctx, - generator, - value.load_size(ctx, None), - )? - else { - return Ok(()); - }; - list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind); + ExprKind::Tuple { elts, .. } | ExprKind::List { elts, .. } => { + // Fold on `"[" [target_list] "]"` and `"(" [target_list] ")"` + generator.gen_assign_target_list(ctx, elts, value, value_ty)?; } _ => { + // Handle attribute and direct variable assignments. let name = if let ExprKind::Name { id, .. } = &target.node { format!("{id}.addr") } else { @@ -293,6 +182,232 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( Ok(()) } +pub fn gen_assign_target_list<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + targets: &Vec>>, + value: ValueEnum<'ctx>, + value_ty: Type, +) -> Result<(), String> { + // Deconstruct the tuple `value` + let BasicValueEnum::StructValue(tuple) = value.to_basic_value_enum(ctx, generator, value_ty)? + else { + unreachable!() + }; + + // NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer. + let TypeEnum::TTuple { ty: tuple_tys } = &*ctx.unifier.get_ty(value_ty) else { + unreachable!(); + }; + + assert_eq!(tuple.get_type().count_fields() as usize, tuple_tys.len()); + + let tuple = (0..tuple.get_type().count_fields()) + .map(|i| ctx.builder.build_extract_value(tuple, i, "item").unwrap()) + .collect_vec(); + + // Find the starred target if it exists. + let mut starred_target_index: Option = None; // Index of the "starred" target. If it exists, there may only be one. + for (i, target) in targets.iter().enumerate() { + if matches!(target.node, ExprKind::Starred { .. }) { + assert!(starred_target_index.is_none()); // The typechecker ensures this + starred_target_index = Some(i); + } + } + + if let Some(starred_target_index) = starred_target_index { + assert!(tuple_tys.len() >= targets.len() - 1); // The typechecker ensures this + + let a = starred_target_index; // Number of RHS values before the starred target + let b = tuple_tys.len() - (targets.len() - 1 - starred_target_index); // Number of RHS values after the starred target + // Thus `tuple[a..b]` is assigned to the starred target. + + // Handle assignment before the starred target + for (target, val, val_ty) in + izip!(&targets[..starred_target_index], &tuple[..a], &tuple_tys[..a]) + { + generator.gen_assign(ctx, target, ValueEnum::Dynamic(*val), *val_ty)?; + } + + // Handle assignment to the starred target + if let ExprKind::Starred { value: target, .. } = &targets[starred_target_index].node { + let vals = &tuple[a..b]; + let val_tys = &tuple_tys[a..b]; + + // Create a sub-tuple from `value` for the starred target. 
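+            // For example (a hypothetical case for illustration, not from the patch itself):
+            // in `x, *ys, z = (1, 2, 3, 4)`, `vals` here is the middle slice `(2, 3)`, which
+            // is materialized as an anonymous LLVM struct value `{ i32, i32 }` so that `ys`
+            // can be assigned like any other tuple.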
+ let sub_tuple_ty = ctx + .ctx + .struct_type(&vals.iter().map(BasicValueEnum::get_type).collect_vec(), false); + let psub_tuple_val = + ctx.builder.build_alloca(sub_tuple_ty, "starred_target_value_ptr").unwrap(); + for (i, val) in vals.iter().enumerate() { + let pitem = ctx + .builder + .build_struct_gep(psub_tuple_val, i as u32, "starred_target_value_item") + .unwrap(); + ctx.builder.build_store(pitem, *val).unwrap(); + } + let sub_tuple_val = + ctx.builder.build_load(psub_tuple_val, "starred_target_value").unwrap(); + + // Create the typechecker type of the sub-tuple + let sub_tuple_ty = ctx.unifier.add_ty(TypeEnum::TTuple { ty: val_tys.to_vec() }); + + // Now assign with that sub-tuple to the starred target. + generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?; + } else { + unreachable!() // The typechecker ensures this + } + + // Handle assignment after the starred target + for (target, val, val_ty) in + izip!(&targets[starred_target_index + 1..], &tuple[b..], &tuple_tys[b..]) + { + generator.gen_assign(ctx, target, ValueEnum::Dynamic(*val), *val_ty)?; + } + } else { + assert_eq!(tuple_tys.len(), targets.len()); // The typechecker ensures this + + for (target, val, val_ty) in izip!(targets, tuple, tuple_tys) { + generator.gen_assign(ctx, target, ValueEnum::Dynamic(val), *val_ty)?; + } + } + Ok(()) +} + +/// See [`CodeGenerator::gen_setitem`]. +pub fn gen_setitem<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target: &Expr>, + key: &Expr>, + value: ValueEnum<'ctx>, + value_ty: Type, +) -> Result<(), String> { + let target_ty = target.custom.unwrap(); + let key_ty = key.custom.unwrap(); + + match &*ctx.unifier.get_ty(target_ty) { + TypeEnum::TObj { obj_id, params: list_params, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // Handle list item assignment + let llvm_usize = generator.get_size_type(ctx.ctx); + let target_item_ty = iter_type_vars(list_params).next().unwrap().ty; + + let target = generator + .gen_expr(ctx, target)? + .unwrap() + .to_basic_value_enum(ctx, generator, target_ty)? + .into_pointer_value(); + let target = ListValue::from_ptr_val(target, llvm_usize, None); + + if let ExprKind::Slice { .. } = &key.node { + // Handle assigning to a slice + let ExprKind::Slice { lower, upper, step } = &key.node else { unreachable!() }; + let Some((start, end, step)) = handle_slice_indices( + lower, + upper, + step, + ctx, + generator, + target.load_size(ctx, None), + )? + else { + return Ok(()); + }; + + let value = + value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value(); + let value = ListValue::from_ptr_val(value, llvm_usize, None); + + let target_item_ty = ctx.get_llvm_type(generator, target_item_ty); + let Some(src_ind) = handle_slice_indices( + &None, + &None, + &None, + ctx, + generator, + value.load_size(ctx, None), + )? + else { + return Ok(()); + }; + list_slice_assignment( + generator, + ctx, + target_item_ty, + target, + (start, end, step), + value, + src_ind, + ); + } else { + // Handle assigning to an index + let len = target.load_size(ctx, Some("len")); + + let index = generator + .gen_expr(ctx, key)? + .unwrap() + .to_basic_value_enum(ctx, generator, key_ty)? 
+ .into_int_value(); + let index = ctx + .builder + .build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext") + .unwrap(); + + // handle negative index + let is_negative = ctx + .builder + .build_int_compare( + IntPredicate::SLT, + index, + generator.get_size_type(ctx.ctx).const_zero(), + "is_neg", + ) + .unwrap(); + let adjusted = ctx.builder.build_int_add(index, len, "adjusted").unwrap(); + let index = ctx + .builder + .build_select(is_negative, adjusted, index, "index") + .map(BasicValueEnum::into_int_value) + .unwrap(); + + // unsigned less than is enough, because negative index after adjustment is + // bigger than the length (for unsigned cmp) + let bound_check = ctx + .builder + .build_int_compare(IntPredicate::ULT, index, len, "inbound") + .unwrap(); + ctx.make_assert( + generator, + bound_check, + "0:IndexError", + "index {0} out of bounds 0:{1}", + [Some(index), Some(len), None], + key.location, + ); + + // Write value to index on list + let item_ptr = + target.data().ptr_offset(ctx, generator, &index, Some("list_item_ptr")); + let value = value.to_basic_value_enum(ctx, generator, value_ty)?; + ctx.builder.build_store(item_ptr, value).unwrap(); + } + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + // Handle NDArray item assignment + todo!("ndarray subscript assignment is not yet implemented"); + } + _ => { + panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); + } + } + Ok(()) +} + /// See [`CodeGenerator::gen_for`]. pub fn gen_for( generator: &mut G, @@ -402,7 +517,7 @@ pub fn gen_for( .unwrap(); generator.gen_block(ctx, body.iter())?; } - TypeEnum::TObj { obj_id, .. } + TypeEnum::TObj { obj_id, params: list_params, .. } if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => { let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?; @@ -442,8 +557,8 @@ pub fn gen_for( .map(BasicValueEnum::into_int_value) .unwrap(); let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val")); - - generator.gen_assign(ctx, target, val.into())?; + let val_ty = iter_type_vars(list_params).next().unwrap().ty; + generator.gen_assign(ctx, target, val.into(), val_ty)?; generator.gen_block(ctx, body.iter())?; } _ => { @@ -1604,14 +1719,14 @@ pub fn gen_stmt( } StmtKind::AnnAssign { target, value, .. } => { if let Some(value) = value { - let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) }; - generator.gen_assign(ctx, target, value)?; + let Some(value_enum) = generator.gen_expr(ctx, value)? else { return Ok(()) }; + generator.gen_assign(ctx, target, value_enum, value.custom.unwrap())?; } } StmtKind::Assign { targets, value, .. } => { - let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) }; + let Some(value_enum) = generator.gen_expr(ctx, value)? else { return Ok(()) }; for target in targets { - generator.gen_assign(ctx, target, value.clone())?; + generator.gen_assign(ctx, target, value_enum.clone(), value.custom.unwrap())?; } } StmtKind::Continue { .. } => { @@ -1625,15 +1740,16 @@ pub fn gen_stmt( StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::AugAssign { target, op, value, .. } => { - let value = gen_binop_expr( + let value_enum = gen_binop_expr( generator, ctx, target, Binop::aug_assign(*op), value, stmt.location, - )?; - generator.gen_assign(ctx, target, value.unwrap())?; + )? 
+ .unwrap(); + generator.gen_assign(ctx, target, value_enum, value.custom.unwrap())?; } StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Raise { exc, .. } => { diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index d296fc6d7..86fc8b8ee 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -34,13 +34,18 @@ impl<'a> Inferencer<'a> { self.should_have_value(pattern)?; Ok(()) } - ExprKind::Tuple { elts, .. } => { + ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => { for elt in elts { self.check_pattern(elt, defined_identifiers)?; self.should_have_value(elt)?; } Ok(()) } + ExprKind::Starred { value, .. } => { + self.check_pattern(value, defined_identifiers)?; + self.should_have_value(value)?; + Ok(()) + } ExprKind::Subscript { value, slice, .. } => { self.check_expr(value, defined_identifiers)?; self.should_have_value(value)?; diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 88ae4d41b..0a98a575b 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1,7 +1,6 @@ use std::collections::{HashMap, HashSet}; use std::convert::{From, TryInto}; use std::iter::once; -use std::ops::Not; use std::{cell::RefCell, sync::Arc}; use super::{ @@ -19,6 +18,7 @@ use crate::{ numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, TopLevelDef, }, + typecheck::typedef::Mapping, }; use itertools::{izip, Itertools}; use nac3parser::ast::{ @@ -123,6 +123,25 @@ fn report_type_error( Err(HashSet::from([TypeError::new(kind, loc).to_display(unifier).to_string()])) } +/// Traverse through a LHS expression in an assignment and set [`ExprContext`] to [`ExprContext::Store`] +/// when appropriate. +/// +/// nac3parser's `ExprContext` output is generally incorrect, and requires manual fixes. +fn fix_assignment_target_context(node: &mut ast::Located) { + match &mut node.node { + ExprKind::Name { ctx, .. } + | ExprKind::Attribute { ctx, .. } + | ExprKind::Subscript { ctx, .. } => { + *ctx = ExprContext::Store; + } + ExprKind::Tuple { ctx, elts } | ExprKind::List { ctx, elts } => { + *ctx = ExprContext::Store; + elts.iter_mut().for_each(fix_assignment_target_context); + } + _ => {} + } +} + impl<'a> Fold<()> for Inferencer<'a> { type TargetU = Option; type Error = InferenceError; @@ -131,18 +150,13 @@ impl<'a> Fold<()> for Inferencer<'a> { Ok(None) } - fn fold_stmt( - &mut self, - mut node: ast::Stmt<()>, - ) -> Result, Self::Error> { + fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result, Self::Error> { let stmt = match node.node { // we don't want fold over type annotation ast::StmtKind::AnnAssign { mut target, annotation, value, simple, config_comment } => { + fix_assignment_target_context(&mut target); // Fix parser bug + self.infer_pattern(&target)?; - // fix parser problem... - if let ExprKind::Attribute { ctx, .. } = &mut target.node { - *ctx = ExprContext::Store; - } let target = Box::new(self.fold_expr(*target)?); let value = if let Some(v) = value { @@ -304,69 +318,41 @@ impl<'a> Fold<()> for Inferencer<'a> { custom: None, } } - ast::StmtKind::Assign { ref mut targets, ref config_comment, .. } => { - for target in &mut *targets { - if let ExprKind::Attribute { ctx, .. } = &mut target.node { - *ctx = ExprContext::Store; - } - } - if targets.iter().all(|t| matches!(t.node, ExprKind::Name { .. })) { - let ast::StmtKind::Assign { targets, value, .. 
} = node.node else { - unreachable!() - }; + ast::StmtKind::Assign { mut targets, type_comment, config_comment, value, .. } => { + // Fix parser bug + targets.iter_mut().for_each(fix_assignment_target_context); - let value = self.fold_expr(*value)?; - let value_ty = value.custom.unwrap(); - let targets: Result, _> = targets - .into_iter() - .map(|target| { - let ExprKind::Name { id, ctx } = target.node else { unreachable!() }; + // NOTE: Do not register identifiers into `self.defined_identifiers` before checking targets + // and value, otherwise the Inferencer might use undefined variables in `self.defined_identifiers` + // and produce strange errors. - self.defined_identifiers.insert(id); - let target_ty = if let Some(ty) = self.variable_mapping.get(&id) { - *ty - } else { - let unifier: &mut Unifier = self.unifier; - self.function_data - .resolver - .get_symbol_type( - unifier, - &self.top_level.definitions.read(), - self.primitives, - id, - ) - .unwrap_or_else(|_| { - self.variable_mapping.insert(id, value_ty); - value_ty - }) - }; - let location = target.location; - self.unifier.unify(value_ty, target_ty).map(|()| Located { - location, - node: ExprKind::Name { id, ctx }, - custom: Some(target_ty), - }) - }) - .collect(); - let loc = node.location; - let targets = targets.map_err(|e| { - HashSet::from([e.at(Some(loc)).to_display(self.unifier).to_string()]) - })?; - return Ok(Located { - location: node.location, - node: ast::StmtKind::Assign { - targets, - value: Box::new(value), - type_comment: None, - config_comment: config_comment.clone(), - }, - custom: None, - }); - } - for target in targets { + let value = self.fold_expr(*value)?; + + let targets: Vec<_> = targets + .into_iter() + .map(|target| -> Result<_, InferenceError> { + // In cases like `x = y = z = rhs`, `rhs`'s type will be constrained by + // the intersection of `x`, `y`, and `z` here. + let target = self.fold_assign_target(target, value.custom.unwrap())?; + Ok(target) + }) + .try_collect()?; + + // Do this only after folding targets and value + for target in &targets { self.infer_pattern(target)?; } - fold::fold_stmt(self, node)? + + Located { + location: node.location, + node: ast::StmtKind::Assign { + targets, + type_comment, + config_comment, + value: Box::new(value), + }, + custom: None, + } } ast::StmtKind::With { ref items, .. } => { for item in items { @@ -379,7 +365,8 @@ impl<'a> Fold<()> for Inferencer<'a> { _ => fold::fold_stmt(self, node)?, }; match &stmt.node { - ast::StmtKind::AnnAssign { .. } + ast::StmtKind::Assign { .. } + | ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Break { .. } | ast::StmtKind::Continue { .. } | ast::StmtKind::Expr { .. } @@ -389,11 +376,6 @@ impl<'a> Fold<()> for Inferencer<'a> { ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => { self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; } - ast::StmtKind::Assign { targets, value, .. } => { - for target in targets { - self.unify(target.custom.unwrap(), value.custom.unwrap(), &target.location)?; - } - } ast::StmtKind::Raise { exc, cause, .. } => { if let Some(cause) = cause { return report_error("raise ... from cause is not supported", cause.location); @@ -533,6 +515,7 @@ impl<'a> Fold<()> for Inferencer<'a> { } _ => fold::fold_expr(self, node)?, }; + let custom = match &expr.node { ExprKind::Constant { value, .. } => Some(self.infer_constant(value, &expr.location)?), ExprKind::Name { id, .. } => { @@ -580,8 +563,6 @@ impl<'a> Fold<()> for Inferencer<'a> { Some(self.infer_identifier(*id)?) 
} } - ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), - ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), ExprKind::Attribute { value, attr, ctx } => { Some(self.infer_attribute(value, *attr, *ctx)?) } @@ -595,8 +576,10 @@ impl<'a> Fold<()> for Inferencer<'a> { ExprKind::Compare { left, ops, comparators } => { Some(self.infer_compare(expr.location, left, ops, comparators)?) } - ExprKind::Subscript { value, slice, ctx, .. } => { - Some(self.infer_subscript(value.as_ref(), slice.as_ref(), *ctx)?) + ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), + ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), + ExprKind::Subscript { value, slice, .. } => { + Some(self.infer_getitem(value.as_ref(), slice.as_ref())?) } ExprKind::IfExp { test, body, orelse } => { Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?) @@ -629,7 +612,7 @@ impl<'a> Inferencer<'a> { }) } - fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), InferenceError> { + fn infer_pattern(&mut self, pattern: &ast::Expr) -> Result<(), InferenceError> { match &pattern.node { ExprKind::Name { id, .. } => { if !self.defined_identifiers.contains(id) { @@ -643,6 +626,13 @@ impl<'a> Inferencer<'a> { } Ok(()) } + ExprKind::List { elts, .. } => { + for elt in elts { + self.infer_pattern(elt)?; + } + Ok(()) + } + ExprKind::Starred { value, .. } => self.infer_pattern(value), _ => Ok(()), } } @@ -1943,28 +1933,270 @@ impl<'a> Inferencer<'a> { Ok(res.unwrap()) } - /// Infers the type of a subscript expression on an `ndarray`. - fn infer_subscript_ndarray( + /// Fold an assignment `"target_list"` recursively, and check RHS's type. + /// See definition of `"target_list"` in . + fn fold_assign_target_list( &mut self, - value: &ast::Expr>, - slice: &ast::Expr>, - dummy_tvar: Type, - ndims: Type, - ) -> InferenceResult { - debug_assert!(matches!( - &*self.unifier.get_ty_immutable(dummy_tvar), - TypeEnum::TVar { is_const_generic: false, .. } - )); - - let constrained_ty = - make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims)); - self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?; - - let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else { - panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims)) + target_list_location: &Location, + mut targets: Vec>, + rhs_ty: Type, + ) -> Result>>, InferenceError> { + // TODO: Allow bidirectional typechecking? Currently RHS's type has to be resolved. + let TypeEnum::TTuple { ty: rhs_tys } = &*self.unifier.get_ty(rhs_ty) else { + // TODO: Allow RHS AST-aware error reporting + return report_error( + "LHS target list pattern requires RHS to be a tuple type", + *target_list_location, + ); }; - let ndims = values + // Find the starred target if it exists. + let mut starred_target_index: Option = None; // Index of the "starred" target. If it exists, there may only be one. + for (i, target) in targets.iter().enumerate() { + if matches!(target.node, ExprKind::Starred { .. }) { + if starred_target_index.is_none() { + // First "starred" target found. + starred_target_index = Some(i); + } else { + // Second "starred" targets found. This is an error. 
+                    return report_error(
+                        "there can only be one starred target, but found another one",
+                        target.location,
+                    );
+                }
+            }
+        }
+
+        let mut folded_targets: Vec<ast::Expr<Option<Type>>> = Vec::new();
+        if let Some(starred_target_index) = starred_target_index {
+            if rhs_tys.len() < targets.len() - 1 {
+                /*
+                    Rules:
+                    ```
+                    (x, *ys, z) = (1,)       # error
+                    (x, *ys, z) = (1, 2)     # ok, ys = ()
+                    (x, *ys, z) = (1, 2, 3)  # ok, ys = (2,)
+                    ```
+                */
+                return report_error(
+                    &format!(
+                        "Target list pattern requires the RHS tuple to have at least {} element(s), but it only has {} element(s)",
+                        targets.len() - 1,
+                        rhs_tys.len()
+                    ),
+                    *target_list_location
+                );
+            }
+
+            /*
+                (a, b, c, ..., *xs, ..., x, y, z)
+                before ^^^^^^^^^^^^  ^^^  ^^^^^^^^^^^^ after
+                            starred
+            */
+
+            let targets_after = targets.drain(starred_target_index + 1..).collect_vec();
+            let target_starred = targets.pop().unwrap();
+            let targets_before = targets;
+
+            let a = targets_before.len();
+            let b = rhs_tys.len() - targets_after.len();
+
+            let rhs_tys_before = &rhs_tys[..a];
+            let rhs_tys_starred = &rhs_tys[a..b];
+            let rhs_tys_after = &rhs_tys[b..];
+
+            // Fold before the starred target
+            for (target, rhs_ty) in izip!(targets_before, rhs_tys_before) {
+                folded_targets.push(self.fold_assign_target(target, *rhs_ty)?);
+            }
+
+            // Fold the starred target
+            if let ExprKind::Starred { value: target, .. } = target_starred.node {
+                let ty = self.unifier.add_ty(TypeEnum::TTuple { ty: rhs_tys_starred.to_vec() });
+                let folded_target = self.fold_assign_target(*target, ty)?;
+                folded_targets.push(Located {
+                    location: target_starred.location,
+                    node: ExprKind::Starred {
+                        value: Box::new(folded_target),
+                        ctx: ExprContext::Store,
+                    },
+                    custom: None,
+                });
+            } else {
+                unreachable!()
+            }
+
+            // Fold after the starred target
+            for (target, rhs_ty) in izip!(targets_after, rhs_tys_after) {
+                folded_targets.push(self.fold_assign_target(target, *rhs_ty)?);
+            }
+        } else {
+            // Fold the target list without a starred target.
+            if rhs_tys.len() != targets.len() {
+                return report_error(
+                    &format!(
+                        "Target list pattern requires the RHS tuple to have exactly {} element(s), but it has {} element(s)",
+                        targets.len(),
+                        rhs_tys.len()
+                    ),
+                    *target_list_location
+                );
+            }
+
+            for (target, rhs_ty) in izip!(targets, rhs_tys) {
+                folded_targets.push(self.fold_assign_target(target, *rhs_ty)?);
+            }
+        }
+
+        Ok(folded_targets)
+    }
+
+    /// Fold an assignment "target" recursively, and check the RHS's type.
+    /// See the definition of "target" in <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
+    fn fold_assign_target(
+        &mut self,
+        target: ast::Expr<()>,
+        rhs_ty: Type,
+    ) -> Result<ast::Expr<Option<Type>>, InferenceError> {
+        match target.node {
+            ExprKind::Name { id, .. } => {
+                // Fold on "identifier"
+                match self.variable_mapping.get(&id) {
+                    None => {
+                        // Assigning to a new variable name; the RHS's type could be anything.
+                        let expected_rhs_ty = self
+                            .unifier
+                            .get_fresh_var(
+                                Some(format!("type_of_{id}").into()),
+                                Some(target.location),
+                            )
+                            .ty;
+                        self.variable_mapping.insert(id, expected_rhs_ty); // Register the new variable
+                        self.constrain(rhs_ty, expected_rhs_ty, &target.location)?;
+                    }
+                    Some(expected_rhs_ty) => {
+                        // Re-assigning to an existing variable name.
+                        self.constrain(rhs_ty, *expected_rhs_ty, &target.location)?;
+                    }
+                };
+                Ok(Located {
+                    location: target.location,
+                    node: ExprKind::Name { id, ctx: ExprContext::Store },
+                    custom: Some(rhs_ty), // Type info is needed here by the CodeGenerator.
+                })
+            }
+            ExprKind::Attribute { ..
} => { + // Fold on "attributeref" + let pattern = self.fold_expr(target)?; + let expected_rhs_ty = pattern.custom.unwrap(); + self.constrain(rhs_ty, expected_rhs_ty, &pattern.location)?; + Ok(pattern) + } + ExprKind::Subscript { value: target, slice: key, .. } => { + // Fold on "slicing" or "subscription" + // TODO: Make `__setitem__` a general object field like `__add__` in NAC3? + let target = self.fold_expr(*target)?; + let key = self.fold_expr(*key)?; + + let expected_rhs_ty = self.infer_setitem_value_type(&target, &key)?; + self.constrain(rhs_ty, expected_rhs_ty, &target.location)?; + + Ok(Located { + location: target.location, + node: ExprKind::Subscript { + value: Box::new(target), + slice: Box::new(key), + ctx: ExprContext::Store, + }, + custom: None, // We don't need to know the type of `target[key]` + }) + } + ExprKind::List { elts, .. } => { + // Fold on `"[" [target_list] "]"` + let elts = self.fold_assign_target_list(&target.location, elts, rhs_ty)?; + Ok(Located { + location: target.location, + node: ExprKind::List { ctx: ExprContext::Store, elts }, + custom: None, + }) + } + ExprKind::Tuple { elts, .. } => { + // Fold on `"(" [target_list] ")"` + let elts = self.fold_assign_target_list(&target.location, elts, rhs_ty)?; + Ok(Located { + location: target.location, + node: ExprKind::Tuple { ctx: ExprContext::Store, elts }, + custom: None, + }) + } + ExprKind::Starred { .. } => report_error( + "starred assignment target must be in a list or tuple", + target.location, + ), + _ => report_error("encountered unsupported/illegal LHS pattern", target.location), + } + } + + /// Typecheck the subscript slice indexing into an ndarray. + /// + /// That is: + /// ```python + /// my_ndarray[::-2, 1, :, None, 9:23] + /// ^^^^^^^^^^^^^^^^^^^^^^ this + /// ``` + /// + /// The number of dimensions to subtract from the ndarray being indexed is also calculated and returned, + /// it could even be negative when more axes are added because of `None`. + fn fold_ndarray_subscript_slice( + &mut self, + slice: &ast::Expr>, + ) -> Result { + // TODO: Handle `None` / `np.newaxis` + + // Flatten `slice` into subscript indices. + let indices = match &slice.node { + ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(), + _ => vec![slice], + }; + + // Typecheck the subscript indices. + // We will also take the opportunity to deduce `dims_to_subtract` as well + let mut dims_to_subtract: i128 = 0; + for index in indices { + if let ExprKind::Slice { lower, upper, step } = &index.node { + for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { + self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?; + } + } else { + // Treat anything else as an integer index, and force unify their type to int32. + self.unify(index.custom.unwrap(), self.primitives.int32, &index.location)?; + dims_to_subtract += 1; + } + } + + Ok(dims_to_subtract) + } + + /// Check if the `ndims` [`Type`] of an ndarray is valid (e.g., no negative values), + /// and attempt to subtract `ndims` by `dims_to_subtract` and return subtracted `ndims`. + /// + /// `dims_to_subtract` can be set to `0` if you only want to check if `ndims` is valid. + fn check_ndarray_ndims_and_subtract( + &mut self, + target_ty: Type, + ndims: Type, + dims_to_subtract: i128, + ) -> Result { + // Typecheck `ndims`. + let TypeEnum::TLiteral { values: ndims, .. 
} = &*self.unifier.get_ty_immutable(ndims) + else { + panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims)) + }; + assert!(!ndims.is_empty()); + + // Check if there are negative literals. + // NOTE: Don't mix this with subtracting dims, otherwise the user errors could be confusing. + let ndims = ndims .iter() .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) .collect::, _>>() @@ -1975,204 +2207,229 @@ impl<'a> Inferencer<'a> { )]) })?; - assert!(!ndims.is_empty()); + // Infer the new `ndims` after indexing the ndarray with `slice`. + // Disallow subscripting if any Literal value will subscript on an element. + let new_ndims = ndims + .into_iter() + .map(|v| { + let v = i128::from(v) - dims_to_subtract; + u64::try_from(v) + }) + .collect::, _>>() + .map_err(|_| { + HashSet::from([format!( + "Cannot subscript {} by {dims_to_subtract} dimension(s)", + self.unifier.stringify(target_ty), + )]) + })?; - // The number of dimensions subscripted by the index expression. - // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a - // dimension will remove a dimension. - let subscripted_dims = match &slice.node { - ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| { - if let ExprKind::Slice { .. } = &value_subexpr.node { - acc - } else { - acc + 1 - } - }), + let new_ndims_ty = self + .unifier + .get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None); - ExprKind::Slice { .. } => 0, - _ => 1, + Ok(new_ndims_ty) + } + + /// Infer the type of the result of indexing into an ndarray. + /// + /// * `ndarray_ty` - The [`Type`] of the ndarray being indexed into. + /// * `slice` - The subscript expression indexing into the ndarray. + fn infer_ndarray_subscript( + &mut self, + ndarray_ty: Type, + slice: &ast::Expr>, + ) -> InferenceResult { + let (dtype, ndims) = unpack_ndarray_var_tys(self.unifier, ndarray_ty); + + let dims_to_substract = self.fold_ndarray_subscript_slice(slice)?; + let new_ndims = + self.check_ndarray_ndims_and_subtract(ndarray_ty, ndims, dims_to_substract)?; + + // Now we need extra work to check `new_ndims` to see if the user has indexed into a single element. + + let TypeEnum::TLiteral { values: new_ndims_values, .. } = &*self.unifier.get_ty(new_ndims) + else { + unreachable!("infer_ndarray_ndims should always return TLiteral") }; - if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 { - // ndarray[T, Literal[1]] - Non-Slice index always returns an object of type T + let new_ndims_values = new_ndims_values + .iter() + .map(|v| u64::try_from(v.clone()).expect("new_ndims should be convertible to u64")) + .collect_vec(); - assert_ne!(ndims[0], 0); - - Ok(dummy_tvar) + if new_ndims_values.len() == 1 && new_ndims_values[0] == 0 { + // The subscripted ndarray must be unsized + // The user must be indexing into a single element + Ok(dtype) } else { - // Otherwise - Index returns an object of type ndarray[T, Literal[N - subscripted_dims]] + // The subscripted ndarray is not unsized / may not be unsized. 
(i.e., may or may not have indexed into a single element) - // Disallow subscripting if any Literal value will subscript on an element - let new_ndims = ndims - .into_iter() - .map(|v| { - let v = i128::from(v) - i128::from(subscripted_dims); - u64::try_from(v) - }) - .collect::, _>>() - .map_err(|_| { - HashSet::from([format!( - "Cannot subscript {} by {subscripted_dims} dimensions", - self.unifier.stringify(value.custom.unwrap()), - )]) - })?; - - if new_ndims.iter().any(|v| *v == 0) { + if new_ndims_values.iter().any(|v| *v == 0) { + // TODO: Difficult to implement since now the return may both be a scalar type, or an ndarray type. unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented") } - let ndims_ty = self - .unifier - .get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None); - let subscripted_ty = - make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty)); - - Ok(subscripted_ty) + let new_ndarray_ty = + make_ndarray_ty(self.unifier, self.primitives, Some(dtype), Some(new_ndims)); + Ok(new_ndarray_ty) } } - fn infer_subscript( + /// Infer the type of the result of indexing into a list. + /// + /// * `list_ty` - The [`Type`] of the list being indexed into. + /// * `key` - The subscript expression indexing into the list. + fn infer_list_subscript( &mut self, - value: &ast::Expr>, - slice: &ast::Expr>, - ctx: ExprContext, + list_ty: Type, + key: &ast::Expr>, + ) -> Result { + let TypeEnum::TObj { params: list_params, .. } = &*self.unifier.get_ty(list_ty) else { + unreachable!() + }; + let item_ty = iter_type_vars(list_params).nth(0).unwrap().ty; + + if let ExprKind::Slice { lower, upper, step } = &key.node { + // Typecheck on the slice + for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { + let v_ty = v.custom.unwrap(); + self.constrain(v_ty, self.primitives.int32, &v.location)?; + } + Ok(list_ty) // type list[T] + } else { + // Treat anything else as an integer index, and force unify their type to int32. + self.constrain(key.custom.unwrap(), self.primitives.int32, &key.location)?; + Ok(item_ty) // type T + } + } + + /// Generate a type that constrains the type of `target` to have a `__getitem__` at `index`. + /// + /// * `target` - The target being indexed by `index`. + /// * `index` - The constant index. + /// * `mutable` - Should the constraint be mutable or immutable? + fn get_constant_index_item_type( + &mut self, + target: &ast::Expr>, + index: i128, + mutable: bool, ) -> InferenceResult { - let report_unscriptable_error = |unifier: &mut Unifier| { - // User is attempting to index into a value of an unsupported type. - - let value_ty = value.custom.unwrap(); - let value_ty_str = unifier.stringify(value_ty); - - return report_error( - format!("'{value_ty_str}' object is not subscriptable").as_str(), - slice.location, // using the slice's location (rather than value's) because it is more clear - ); + let Ok(index) = i32::try_from(index) else { + return Err(HashSet::from(["Index must be int32".to_string()])); }; - let ty = self.unifier.get_dummy_var().ty; - match &slice.node { - ExprKind::Slice { lower, upper, step } => { - for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { - self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?; - } - let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { - TypeEnum::TObj { obj_id, params, .. 
} if *obj_id == PrimDef::List.id() => { - let list_tvar = iter_type_vars(params).nth(0).unwrap(); - self.unifier - .subst( - self.primitives.list, - &into_var_map([TypeVar { id: list_tvar.id, ty }]), - ) - .unwrap() - } + let item_ty = self.unifier.get_dummy_var().ty; // To be resolved by the unifier - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (_, ndims) = - unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); + // Constrain `target` + let fields_constrain = Mapping::from_iter([( + RecordKey::Int(index), + RecordField::new(item_ty, mutable, Some(target.location)), + )]); + let fields_constrain_ty = self.unifier.add_record(fields_constrain); + self.constrain(target.custom.unwrap(), fields_constrain_ty, &target.location)?; - make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) - } + Ok(item_ty) + } - _ => { - return report_unscriptable_error(self.unifier); - } - }; - self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?; - Ok(list_like_ty) + /// Infer the return type of a `__getitem__` expression. + /// + /// i.e., `target[key]`, where the [`ExprContext`] is [`ExprContext::Load`]. + fn infer_getitem( + &mut self, + target: &ast::Expr>, + key: &ast::Expr>, + ) -> InferenceResult { + let target_ty = target.custom.unwrap(); + + match &*self.unifier.get_ty(target_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == self.primitives.list.obj_id(self.unifier).unwrap() => + { + self.infer_list_subscript(target_ty, key) } - ExprKind::Constant { value: ast::Constant::Int(val), .. } => { - match &*self.unifier.get_ty(value.custom.unwrap()) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (_, ndims) = - unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); - self.infer_subscript_ndarray(value, slice, ty, ndims) - } - _ => { - // the index is a constant, so value can be a sequence. - let ind: Option = (*val).try_into().ok(); - let ind = - ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?; - let map = once(( - ind.into(), - RecordField::new(ty, ctx == ExprContext::Store, Some(value.location)), - )) - .collect(); - let seq = self.unifier.add_record(map); - self.constrain(value.custom.unwrap(), seq, &value.location)?; - Ok(ty) - } - } - } - ExprKind::Tuple { elts, .. } => { - if value - .custom - .unwrap() - .obj_id(self.unifier) - .is_some_and(|id| id == PrimDef::NDArray.id()) - .not() - { - return report_error( - "Tuple slices are only supported for ndarrays", - slice.location, - ); - } - - for elt in elts { - if let ExprKind::Slice { lower, upper, step } = &elt.node { - for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { - self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?; - } - } else { - self.constrain(elt.custom.unwrap(), self.primitives.int32, &elt.location)?; - } - } - - let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); - self.infer_subscript_ndarray(value, slice, ty, ndims) + TypeEnum::TObj { obj_id, .. } + if *obj_id == self.primitives.ndarray.obj_id(self.unifier).unwrap() => + { + self.infer_ndarray_subscript(target_ty, key) } _ => { - if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { - return report_error( - "Tuple index must be a constant (KernelInvariant is also not supported)", - slice.location, - ); + // Now `target_ty` either: + // 1) is a `TTuple`, or + // 2) is simply not obvious for doing __getitem__ on. 
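+                // For example (illustrative, not from the patch): `t = (1, True); x = t[0]`
+                // is resolvable because the index is the constant 0, so the unifier can
+                // constrain `t` with a record that has field 0 of the item's type.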
+
+                if let ExprKind::Constant { value: ast::Constant::Int(index), .. } = &key.node {
+                    // If `key` is a constant int, then the value can be a sequence.
+                    // Therefore, this can be handled by the unifier.
+                    let getitem_ty = self.get_constant_index_item_type(target, *index, false)?;
+                    Ok(getitem_ty)
+                } else {
+                    // Out of ways to resolve __getitem__; throw an error.
+                    report_error(
+                        &format!(
+                            "'{}' cannot be indexed by this subscript",
+                            self.unifier.stringify(target_ty)
+                        ),
+                        key.location,
+                    )
+                }
+            }
+        }
+    }
+
+    /// Fold an item assignment, and return a type that constrains the type of the RHS.
+    fn infer_setitem_value_type(
+        &mut self,
+        target: &ast::Expr<Option<Type>>,
+        key: &ast::Expr<Option<Type>>,
+    ) -> Result<Type, InferenceError> {
+        let target_ty = target.custom.unwrap();
+        match &*self.unifier.get_ty(target_ty) {
+            TypeEnum::TObj { obj_id, .. }
+                if *obj_id == self.primitives.list.obj_id(self.unifier).unwrap() =>
+            {
+                // Handle list item assignment
+
+                // The expected value type is the same as the type of list.__getitem__
+                self.infer_list_subscript(target_ty, key)
+            }
+            TypeEnum::TObj { obj_id, .. }
+                if *obj_id == self.primitives.ndarray.obj_id(self.unifier).unwrap() =>
+            {
+                // Handle ndarray item assignment
+
+                // NOTE: `value` can be either an ndarray or a scalar, even if `target` is an unsized ndarray.
+
+                // TODO: NumPy does automatic casting on `value`. (Currently not supported)
+                // See https://numpy.org/doc/stable/user/basics.indexing.html#assigning-values-to-indexed-arrays
+
+                let (scalar_ty, _) = unpack_ndarray_var_tys(self.unifier, target_ty);
+                let ndarray_ty =
+                    make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None);
+
+                let expected_value_ty =
+                    self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray_ty], None, None).ty;
+                Ok(expected_value_ty)
+            }
+            _ => {
+                // Handle item assignments of other types.
+
+                // Now `target_ty` either:
+                // 1) is a `TTuple`, or
+                // 2) is simply not obvious for doing __setitem__ on.
+
+                if let ExprKind::Constant { value: ast::Constant::Int(index), .. } = &key.node {
+                    // If `key` is a constant int, then the value can be a sequence.
+                    // Therefore, this can be handled by the unifier.
+                    self.get_constant_index_item_type(target, *index, false)
+                } else {
+                    // Out of ways to resolve __setitem__; throw an error.
+ report_error( + &format!( + "'{}' does not allow item assignment with this subscript", + self.unifier.stringify(target_ty) + ), + key.location, + ) } } } diff --git a/nac3standalone/demo/src/assignment.py b/nac3standalone/demo/src/assignment.py new file mode 100644 index 000000000..c87ed03a6 --- /dev/null +++ b/nac3standalone/demo/src/assignment.py @@ -0,0 +1,66 @@ +@extern +def output_int32(x: int32): + ... + +@extern +def output_bool(x: bool): + ... + +def example1(): + x, *ys, z = (1, 2, 3, 4, 5) + output_int32(x) + output_int32(ys[0]) + output_int32(ys[1]) + output_int32(ys[2]) + output_int32(z) + +def example2(): + x, y, *zs = (1, 2, 3, 4, 5) + output_int32(x) + output_int32(y) + output_int32(zs[0]) + output_int32(zs[1]) + output_int32(zs[2]) + +def example3(): + *xs, y, z = (1, 2, 3, 4, 5) + output_int32(xs[0]) + output_int32(xs[1]) + output_int32(xs[2]) + output_int32(y) + output_int32(z) + +def example4(): + # Example from: https://docs.python.org/3/reference/simple_stmts.html#assignment-statements + x = [0, 1] + i = 0 + i, x[i] = 1, 2 # i is updated, then x[i] is updated + output_int32(i) + output_int32(x[0]) + output_int32(x[1]) + +class A: + value: int32 + def __init__(self): + self.value = 1000 + +def example5(): + ws = [88, 7, 8] + a = A() + x, [y, *ys, a.value], ws[0], (ws[0],) = 1, (2, False, 4, 5), 99, (6,) + output_int32(x) + output_int32(y) + output_bool(ys[0]) + output_int32(ys[1]) + output_int32(a.value) + output_int32(ws[0]) + output_int32(ws[1]) + output_int32(ws[2]) + +def run() -> int32: + example1() + example2() + example3() + example4() + example5() + return 0
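
A few reference notes follow for readers checking the demo against standard Python. These are illustrative sketches run against CPython, not part of the patch; note the one deliberate divergence called out in the commit message: this patch binds a starred target to a tuple, whereas CPython binds it to a list.

```python
# Expected bindings for the starred-target demos (runnable in CPython).
x, *ys, z = (1, 2, 3, 4, 5)
assert (x, z) == (1, 5)
assert ys == [2, 3, 4]        # CPython: list; nac3 under this patch: tuple (2, 3, 4)

x, y, *zs = (1, 2, 3, 4, 5)
assert (x, y) == (1, 2) and zs == [3, 4, 5]

*xs, y, z = (1, 2, 3, 4, 5)
assert xs == [1, 2, 3] and (y, z) == (4, 5)
```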
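The arity rules spelled out in the `fold_assign_target_list` comment block can be checked the same way; CPython reports violations as runtime `ValueError`s, while the inferencer in this patch rejects them at compile time.

```python
# With a starred target, the RHS needs at least len(targets) - 1 elements.
x, *ys, z = (1, 2)       # ok: ys == [] in CPython, () under this patch
x, *ys, z = (1, 2, 3)    # ok: ys == [2] in CPython, (2,) under this patch
try:
    x, *ys, z = (1,)     # too few elements for the non-starred targets
except ValueError as e:
    print("rejected:", e)

# Without a starred target, the arity must match exactly.
try:
    a, b = (1, 2, 3)
except ValueError as e:
    print("rejected:", e)
```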
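`gen_setitem`'s two list paths, index assignment with negative-index adjustment and slice assignment via `list_slice_assignment`, correspond to ordinary Python behaviour:

```python
ls = [1, 2, 3, 4]
ls[1] = 20             # plain index assignment
ls[-1] = 40            # negative index: the emitted code adds len(ls) before the bounds check
ls[0:2] = [10, 20]     # slice assignment, lowered to list_slice_assignment
assert ls == [10, 20, 3, 40]

try:
    ls[99] = 0         # out of bounds: the generated assert raises IndexError
except IndexError as e:
    print("rejected:", e)
```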
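For the ndarray inference (`fold_ndarray_subscript_slice` and `check_ndarray_ndims_and_subtract`), the `dims_to_subtract` bookkeeping mirrors NumPy: integer indices each remove one dimension, slices preserve theirs, and indexing every dimension yields a single element, which is why `infer_ndarray_subscript` returns the dtype when the new `ndims` is exactly 0. A NumPy sketch:

```python
import numpy as np

a = np.zeros((2, 3, 4))        # ndims == 3
assert a[1].ndim == 2          # one integer index subtracts one dimension
assert a[1, :, 2].ndim == 1    # the slice keeps its dimension, the two ints remove theirs
assert a[0, 0, 0].ndim == 0    # fully indexed: a single element (the dtype, in nac3 terms)
```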
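Finally, on `fix_assignment_target_context`: its doc comment notes that nac3parser's `ExprContext` output is generally incorrect for assignment targets. CPython's own `ast` module shows the tagging the helper restores, including for nested target lists:

```python
import ast

tree = ast.parse("x, [y, *ys] = v")
target = tree.body[0].targets[0]                   # the (x, [y, *ys]) tuple target
assert isinstance(target.ctx, ast.Store)
assert isinstance(target.elts[1].ctx, ast.Store)   # the nested [y, *ys] list target
```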