diff --git a/nac3core/src/typecheck/expression_inference.rs b/nac3core/src/typecheck/expression_inference.rs index 510b49b..7c80d0e 100644 --- a/nac3core/src/typecheck/expression_inference.rs +++ b/nac3core/src/typecheck/expression_inference.rs @@ -266,33 +266,43 @@ impl<'a> InferenceContext<'a> { } fn infer_subscript(&self, value: &Box>>, slice: &Box>>) -> Result, String> { - let t = if let TypeEnum::ParametricType(primitives::LIST_TYPE, ls) = value.custom.as_ref().ok_or_else(|| "no value".to_string())?.as_ref() { - ls[0].clone() - } else { - return Err("subscript is not supported for types other than list".into()); - }; - - if let ast::ExprKind::Slice {lower, upper, step} = &slice.node { - let int32_type = self.get_primitive(primitives::INT32_TYPE); - let l = lower.as_ref().map_or( - Ok(&int32_type), - |x| x.custom.as_ref().ok_or("lower bound cannot be typped".to_string()))?; - let u = upper.as_ref().map_or( - Ok(&int32_type), - |x| x.custom.as_ref().ok_or("upper bound cannot be typped".to_string()))?; - let s = step.as_ref().map_or( - Ok(&int32_type), - |x| x.custom.as_ref().ok_or("step cannot be typped".to_string()))?; - - if l == &int32_type && u == &int32_type && s == &int32_type { - Ok(value.custom.clone()) + let val_type = value.custom.as_ref().ok_or_else(|| "no value".to_string())?.as_ref(); + if let TypeEnum::ParametricType(primitives::LIST_TYPE, ls) = val_type { + if let ast::ExprKind::Slice {lower, upper, step} = &slice.node { + let int32_type = self.get_primitive(primitives::INT32_TYPE); + let l = lower.as_ref().map_or( + Ok(&int32_type), + |x| x.custom.as_ref().ok_or("lower bound cannot be typped".to_string()))?; + let u = upper.as_ref().map_or( + Ok(&int32_type), + |x| x.custom.as_ref().ok_or("upper bound cannot be typped".to_string()))?; + let s = step.as_ref().map_or( + Ok(&int32_type), + |x| x.custom.as_ref().ok_or("step cannot be typped".to_string()))?; + + if l == &int32_type && u == &int32_type && s == &int32_type { + Ok(value.custom.clone()) + } else { + Err("slice must be int32 type".into()) + } + } else if slice.custom == Some(self.get_primitive(primitives::INT32_TYPE)) { + Ok(Some(ls[0].clone())) } else { - Err("slice must be int32 type".into()) + Err("slice or index must be int32 type".into()) + } + } else if let TypeEnum::ParametricType(primitives::TUPLE_TYPE, ls) = val_type { + if let ast::ExprKind::Constant {kind: _, value: ast::Constant::Int(val)} = &slice.node { + let ind: Result = val.try_into(); + if ind.is_ok() && ind.unwrap() < ls.len() { + Ok(Some(ls[ind.unwrap()].clone())) + } else { + Err("tuple constant index out of range".into()) + } + } else { + Err("tuple index can only be constant".into()) } - } else if slice.custom == Some(self.get_primitive(primitives::INT32_TYPE)) { - Ok(Some(t)) } else { - Err("slice or index must be int32 type".into()) + Err("subscript is not supported for types other than list or tuple".into()) } } @@ -631,6 +641,7 @@ pub mod test { let ast8 = rustpython_parser::parser::parse_expression("[[(1, 2), (2, 3), (3, 4)], [(2, 4), (4, 6)]][0]").unwrap(); let ast9 = rustpython_parser::parser::parse_expression("[1, 2, 3, 4, 5][1: 2]").unwrap(); let ast10 = rustpython_parser::parser::parse_expression("4 if False and True else 8").unwrap(); + let ast11 = rustpython_parser::parser::parse_expression("(1, 2, 3, 4)[1]").unwrap(); let folded = inf.fold_expr(ast1).unwrap(); let folded_2 = Premapper.fold_expr(ast2).unwrap(); @@ -642,6 +653,7 @@ pub mod test { let folded_8 = inf.fold_expr(ast8).unwrap(); let folded_9 = inf.fold_expr(ast9).unwrap(); let folded_10 = inf.fold_expr(ast10).unwrap(); + let folded_11 = inf.fold_expr(ast11).unwrap(); println!("{:?}", folded.custom); println!("{:?}", folded_2.custom); @@ -653,6 +665,7 @@ pub mod test { println!("{:?}", folded_8.custom); println!("{:?}", folded_9.custom); println!("{:?}", folded_10.custom); + println!("{:?}", folded_11.custom); } } \ No newline at end of file