use crate::context::InferenceContext; use crate::expression_inference::{infer_expr, infer_simple_binding}; use crate::inference_core::resolve_call; use crate::magic_methods::binop_assign_name; use crate::primitives::*; use crate::typedef::{Type, TypeEnum::*}; use rustpython_parser::ast::*; pub fn check_stmts<'b: 'a, 'a>( ctx: &mut InferenceContext<'a>, stmts: &'b [Statement], ) -> Result { for stmt in stmts.iter() { match &stmt.node { StatementType::Assign { targets, value } => { check_assign(ctx, targets.as_slice(), &value)?; } StatementType::AugAssign { target, op, value } => { check_aug_assign(ctx, &target, op, &value)?; } StatementType::If { test, body, orelse } => { if check_if(ctx, test, body.as_slice(), orelse)? { return Ok(true) } } StatementType::While { test, body, orelse } => { check_while_stmt(ctx, test, body.as_slice(), orelse)?; } StatementType::For { is_async, target, iter, body, orelse, } => { if *is_async { return Err("async for is not supported".to_string()); } check_for_stmt(ctx, target, iter, body.as_slice(), orelse)?; } StatementType::Return { value } => { let result = ctx.get_result(); let t = if let Some(value) = value { infer_expr(ctx, value)? } else { None }; return if t == result { Ok(true) } else { Err("return type mismatch".to_string()) }; } StatementType::Continue | StatementType::Break => { continue; } _ => return Err("not supported".to_string()), } } Ok(false) } fn get_target_type<'b: 'a, 'a>( ctx: &mut InferenceContext<'a>, target: &'b Expression, ) -> Result { match &target.node { ExpressionType::Subscript { a, b } => { let int32 = ctx.get_primitive(INT32_TYPE); if infer_expr(ctx, &a)? == Some(int32) { let b = get_target_type(ctx, &b)?; if let ParametricType(LIST_TYPE, t) = b.as_ref() { Ok(t[0].clone()) } else { Err("subscript is only supported for list".to_string()) } } else { Err("subscript must be int32".to_string()) } } ExpressionType::Attribute { value, name } => { let t = get_target_type(ctx, &value)?; let base = t.get_base(ctx).ok_or_else(|| "no attributes".to_string())?; Ok(base .fields .get(name.as_str()) .ok_or_else(|| "no such attribute")? .clone()) } ExpressionType::Identifier { name } => Ok(ctx.resolve(name.as_str())?), _ => Err("not supported".to_string()), } } fn check_stmt_binding<'b: 'a, 'a>( ctx: &mut InferenceContext<'a>, target: &'b Expression, ty: Type, ) -> Result<(), String> { match &target.node { ExpressionType::Identifier { name } => { if name.as_str() == "_" { Ok(()) } else { match ctx.resolve(name.as_str()) { Ok(t) if t == ty => Ok(()), Err(_) => { ctx.assign(name.as_str(), ty).unwrap(); Ok(()) } _ => Err("conflicting type".into()), } } } ExpressionType::Tuple { elements } => { if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() { if ls.len() != elements.len() { return Err("incorrect pattern length".into()); } for (x, y) in elements.iter().zip(ls.iter()) { check_stmt_binding(ctx, x, y.clone())?; } Ok(()) } else { Err("pattern matching supports tuple only".into()) } } _ => { let t = get_target_type(ctx, target)?; if ty == t { Ok(()) } else { Err("type mismatch".into()) } } } } fn check_assign<'b: 'a, 'a>( ctx: &mut InferenceContext<'a>, targets: &'b [Expression], value: &'b Expression, ) -> Result<(), String> { let ty = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?; for t in targets.iter() { check_stmt_binding(ctx, t, ty.clone())?; } Ok(()) } fn check_aug_assign<'b: 'a, 'a>( ctx: &mut InferenceContext<'a>, target: &'b Expression, op: &'b Operator, value: &'b Expression, ) -> Result<(), String> { let left = infer_expr(ctx, target)?.ok_or_else(|| "no value".to_string())?; let right = infer_expr(ctx, value)?.ok_or_else(|| "no value".to_string())?; let fun = binop_assign_name(op); resolve_call(ctx, Some(left), fun, &[right])?; Ok(()) } fn check_if<'b: 'a, 'a>( ctx: &mut InferenceContext<'a>, test: &'b Expression, body: &'b [Statement], orelse: &'b Option, ) -> Result { let boolean = ctx.get_primitive(BOOL_TYPE); let t = infer_expr(ctx, test)?; if t == Some(boolean) { let (names, result) = ctx.with_scope(|ctx| check_stmts(ctx, body)); let returned = result?; if let Some(orelse) = orelse { let (names2, result) = ctx.with_scope(|ctx| check_stmts(ctx, orelse.as_slice())); let returned = returned && result?; for (name, ty) in names.iter() { for (name2, ty2) in names2.iter() { if *name == *name2 && ty == ty2 { ctx.assign(name, ty.clone()).unwrap(); } } } Ok(returned) } else { Ok(false) } } else { Err("condition should be bool".to_string()) } } fn check_while_stmt<'b: 'a, 'a>( ctx: &mut InferenceContext<'a>, test: &'b Expression, body: &'b [Statement], orelse: &'b Option, ) -> Result { let boolean = ctx.get_primitive(BOOL_TYPE); let t = infer_expr(ctx, test)?; if t == Some(boolean) { // to check what variables are defined, we would have to do a graph analysis... // not implemented now let (_, result) = ctx.with_scope(|ctx| check_stmts(ctx, body)); result?; if let Some(orelse) = orelse { let (_, result) = ctx.with_scope(|ctx| check_stmts(ctx, orelse.as_slice())); result?; } // to check whether the loop returned on every possible path, we need to analyse the graph, // not implemented now Ok(false) } else { Err("condition should be bool".to_string()) } } fn check_for_stmt<'b: 'a, 'a>( ctx: &mut InferenceContext<'a>, target: &'b Expression, iter: &'b Expression, body: &'b [Statement], orelse: &'b Option, ) -> Result { let ty = infer_expr(ctx, iter)?.ok_or_else(|| "no value".to_string())?; if let ParametricType(LIST_TYPE, ls) = ty.as_ref() { let (_, result) = ctx.with_scope(|ctx| { infer_simple_binding(ctx, target, ls[0].clone())?; check_stmts(ctx, body) }); result?; if let Some(orelse) = orelse { let (_, result) = ctx.with_scope(|ctx| check_stmts(ctx, orelse.as_slice())); result?; } // to check whether the loop returned on every possible path, we need to analyse the graph, // not implemented now Ok(false) } else { Err("only list can be iterated over".to_string()) } } #[cfg(test)] mod test { use super::*; use crate::context::*; use indoc::indoc; use rustpython_parser::parser::parse_program; fn get_inference_context(ctx: TopLevelContext) -> InferenceContext { InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into()))) } #[test] fn test_assign() { let ctx = basic_ctx(); let mut ctx = get_inference_context(ctx); let ast = parse_program(indoc! {" a = 1 b = a * 2 " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" a = 1 b = b * 2 " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!( Err("unbounded identifier".to_string()), check_stmts(ctx, ast.statements.as_slice()) ); }); let ast = parse_program(indoc! {" b = a = 1 b = b * 2 " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" b = a = 1 b = [a] " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!( Err("conflicting type".to_string()), check_stmts(ctx, ast.statements.as_slice()) ); }); } #[test] fn test_if() { let ctx = basic_ctx(); let mut ctx = get_inference_context(ctx); let ast = parse_program(indoc! {" a = 1 b = a * 2 if b > a: c = 1 else: c = 0 d = c " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" a = 1 b = a * 2 if b > a: c = 1 else: d = 0 d = c " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!( Err("unbounded identifier".to_string()), check_stmts(ctx, ast.statements.as_slice()) ); }); let ast = parse_program(indoc! {" a = 1 b = a * 2 if b > a: c = 1 d = c " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!( Err("unbounded identifier".to_string()), check_stmts(ctx, ast.statements.as_slice()) ); }); let ast = parse_program(indoc! {" a = 1 b = a * 2 if a: b = 0 " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!( Err("condition should be bool".to_string()), check_stmts(ctx, ast.statements.as_slice()) ); }); let ast = parse_program(indoc! {" a = 1 b = a * 2 if b > a: c = 1 c = [1] " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" a = 1 b = a * 2 if b > a: c = 1 else: c = 0 c = [1] " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!( Err("conflicting type".to_string()), check_stmts(ctx, ast.statements.as_slice()) ); }); } #[test] fn test_while() { let ctx = basic_ctx(); let mut ctx = get_inference_context(ctx); let ast = parse_program(indoc! {" a = 1 b = 1 while a < 10: a += 1 b *= a " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" a = 1 b = 1 while a < 10: a += 1 b *= a " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" a = 1 b = 1 while a < 10: a += 1 b *= a else: a += 1 " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" a = 1 b = 1 while a: a += 1 " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!( Err("condition should be bool".to_string()), check_stmts(ctx, ast.statements.as_slice()) ); }); let ast = parse_program(indoc! {" a = 1 b = 1 while a < 10: a += 1 c = a*2 else: c = a*2 b = c " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!( Err("unbounded identifier".to_string()), check_stmts(ctx, ast.statements.as_slice()) ); }); } #[test] fn test_for() { let ctx = basic_ctx(); let mut ctx = get_inference_context(ctx); let ast = parse_program(indoc! {" b = 1 for a in [0, 1, 2, 3, 4, 5]: b *= a " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" b = 1 for a, a1 in [(0, 1), (2, 3), (4, 5)]: b *= a " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); } #[test] fn test_return() { let ctx = basic_ctx(); let mut ctx = get_inference_context(ctx); let ast = parse_program(indoc! {" b = 1 return " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(true), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" b = 1 if b > 0: return " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" b = 1 if b > 0: return else: return " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(true), check_stmts(ctx, ast.statements.as_slice())); }); let ast = parse_program(indoc! {" b = 1 while b > 0: return else: return " }) .unwrap(); ctx.with_scope(|ctx| { // with sophisticated analysis, this one should be Ok(true) // but with our simple implementation, this is Ok(false) // as we don't analyse the control flow assert_eq!(Ok(false), check_stmts(ctx, ast.statements.as_slice())); }); ctx.set_result(Some(ctx.get_primitive(INT32_TYPE))); let ast = parse_program(indoc! {" b = 1 return 1 " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!(Ok(true), check_stmts(ctx, ast.statements.as_slice())); }); ctx.set_result(Some(ctx.get_primitive(INT32_TYPE))); let ast = parse_program(indoc! {" b = 1 return [1] " }) .unwrap(); ctx.with_scope(|ctx| { assert_eq!( Err("return type mismatch".to_string()), check_stmts(ctx, ast.statements.as_slice()) ); }); } }