diff --git a/nac3core/src/typecheck/expression_inference.rs b/nac3core/src/typecheck/expression_inference.rs index 81f6befb4..510b49b3a 100644 --- a/nac3core/src/typecheck/expression_inference.rs +++ b/nac3core/src/typecheck/expression_inference.rs @@ -8,7 +8,20 @@ use crate::typecheck::primitives; use rustpython_parser::ast; use rustpython_parser::ast::fold::Fold; -// REVIEW: direct impl fold trait on InferenceContext +struct Premapper; +impl ast::fold::Fold<()> for Premapper { + type TargetU = Option; + type Error = String; + + fn map_user(&mut self, _user: ()) -> Result { + Ok(None) + } + + fn fold_expr(&mut self, node: ast::Expr<()>) -> Result, Self::Error> { + ast::fold::fold_expr(self, node) + } +} + impl<'a> ast::fold::Fold> for InferenceContext<'a> { type TargetU = Option; type Error = String; @@ -43,6 +56,7 @@ impl<'a> ast::fold::Fold> for InferenceContext<'a> { ast::ExprKind::Subscript {value, slice, ctx: _} => self.infer_subscript(value, slice), ast::ExprKind::IfExp {test, body, orelse} => self.infer_if_expr(test, body, orelse), ast::ExprKind::ListComp {elt, generators} => self.infer_list_comprehesion(elt, generators), + ast::ExprKind::Slice { .. } => Ok(None), // special handling for slice, which is supported _ => Err("not supported yet".into()) }?, location: expr.location, @@ -408,32 +422,39 @@ impl<'a> InferenceContext<'a> { } } +pub struct ExpressionInferencer<'a> { + pub ctx: InferenceContext<'a> +} +impl<'a> ExpressionInferencer<'a> { + pub fn fold_expr(&mut self, expr: ast::Expr) -> Result>, String> { + let expr = Premapper.fold_expr(expr)?; + self.ctx.fold_expr(expr) + } +} + pub mod test { use crate::typecheck::{symbol_resolver::SymbolResolver, typedef::*, symbol_resolver::*, location::*}; use rustpython_parser::ast::{self, Expr, fold::Fold}; use super::*; - pub fn new_ctx<'a>() -> InferenceContext<'a>{ + pub fn new_ctx<'a>() -> ExpressionInferencer<'a> { struct S; - impl SymbolResolver for S { fn get_symbol_location(&self, _str: &str) -> Option { None } - fn get_symbol_type(&self, _str: &str) -> Option { None } - fn get_symbol_value(&self, _str: &str) -> Option { None } } - InferenceContext::new(primitives::basic_ctx(), Box::new(S{}), FileID(3)) + ExpressionInferencer {ctx: InferenceContext::new(primitives::basic_ctx(), Box::new(S{}), FileID(3))} } #[test] fn test_i32() { let mut inferencer = new_ctx(); - let ast: Expr> = Expr { + let ast: Expr = Expr { location: ast::Location::new(0, 0), - custom: None, + custom: (), node: ast::ExprKind::Constant { value: ast::Constant::Int(123.into()), kind: None @@ -445,7 +466,7 @@ pub mod test { new_ast, Ok(ast::Expr { location: ast::Location::new(0, 0), - custom: Some(inferencer.get_primitive(primitives::INT32_TYPE)), + custom: Some(inferencer.ctx.get_primitive(primitives::INT32_TYPE)), node: ast::ExprKind::Constant { value: ast::Constant::Int(123.into()), kind: None @@ -461,9 +482,9 @@ pub mod test { let location = ast::Location::new(0, 0); let num: i64 = 99999999999; - let ast: Expr> = Expr { + let ast: Expr = Expr { location: location, - custom: None, + custom: (), node: ast::ExprKind::Constant { value: ast::Constant::Int(num.into()), kind: None, @@ -476,7 +497,7 @@ pub mod test { new_ast, Expr { location: location, - custom: Some(inferencer.get_primitive(primitives::INT64_TYPE)), + custom: Some(inferencer.ctx.get_primitive(primitives::INT64_TYPE)), node: ast::ExprKind::Constant { value: ast::Constant::Int(num.into()), kind: None, @@ -485,20 +506,67 @@ pub mod test { ); } + #[test] + fn test_tuple() { + let mut inferencer = new_ctx(); + let i32_t = inferencer.ctx.get_primitive(primitives::INT32_TYPE); + let float_t = inferencer.ctx.get_primitive(primitives::FLOAT_TYPE); + let ast = rustpython_parser::parser::parse_expression("(123, 123.123, 999999999)").unwrap(); + let loc = ast.location.clone(); + let folded = inferencer.fold_expr(ast).unwrap(); + assert_eq!( + folded, + ast::Expr { + location: loc, + custom: Some(TypeEnum::ParametricType(primitives::TUPLE_TYPE, vec![i32_t.clone().into(), float_t.clone().into(), i32_t.clone().into()]).into()), + node: ast::ExprKind::Tuple { + ctx: ast::ExprContext::Load, + elts: vec![ + ast::Expr { + location: ast::Location::new(1, 2), + custom: Some(i32_t.clone()), + node: ast::ExprKind::Constant { + value: ast::Constant::Int(123.into()), + kind: None + } + }, + ast::Expr { + location: ast::Location::new(1, 7), + custom: Some(float_t.clone()), + node: ast::ExprKind::Constant { + value: ast::Constant::Float(123.123), + kind: None + } + }, + ast::Expr { + location: ast::Location::new(1, 16), + custom: Some(i32_t.clone()), + node: ast::ExprKind::Constant { + value: ast::Constant::Int(999999999.into()), + kind: None + } + }, + ] + } + } + ); + + } + #[test] fn test_list() { let mut inferencer = new_ctx(); let location = ast::Location::new(0, 0); - let ast: Expr> = Expr { + let ast: Expr = Expr { location, - custom: None, + custom: (), node: ast::ExprKind::List { ctx: ast::ExprContext::Load, elts: vec![ Expr { location, - custom: None, + custom: (), node: ast::ExprKind::Constant { value: ast::Constant::Int(1.into()), kind: None, @@ -507,7 +575,7 @@ pub mod test { Expr { location, - custom: None, + custom: (), node: ast::ExprKind::Constant { value: ast::Constant::Int(2.into()), kind: None, @@ -522,13 +590,13 @@ pub mod test { new_ast, Expr { location, - custom: Some(TypeEnum::ParametricType(primitives::LIST_TYPE, vec![inferencer.get_primitive(primitives::INT32_TYPE).into()]).into()), + custom: Some(TypeEnum::ParametricType(primitives::LIST_TYPE, vec![inferencer.ctx.get_primitive(primitives::INT32_TYPE).into()]).into()), node: ast::ExprKind::List { ctx: ast::ExprContext::Load, elts: vec![ Expr { location, - custom: Some(inferencer.get_primitive(primitives::INT32_TYPE)), + custom: Some(inferencer.ctx.get_primitive(primitives::INT32_TYPE)), node: ast::ExprKind::Constant { value: ast::Constant::Int(1.into()), kind: None, @@ -537,7 +605,7 @@ pub mod test { Expr { location, - custom: Some(inferencer.get_primitive(primitives::INT32_TYPE)), + custom: Some(inferencer.ctx.get_primitive(primitives::INT32_TYPE)), // custom: None, node: ast::ExprKind::Constant { value: ast::Constant::Int(2.into()), @@ -549,4 +617,42 @@ pub mod test { } ); } + + #[test] + fn test_mix() { + let mut inf = new_ctx(); + let ast1 = rustpython_parser::parser::parse_expression("False == [True or True, False][0]").unwrap(); + let ast2 = rustpython_parser::parser::parse_expression("False == [True or True, False][0]").unwrap(); + let ast3 = rustpython_parser::parser::parse_expression("1 < 2 < 3").unwrap(); + let ast4 = rustpython_parser::parser::parse_expression("1 + [12312, 1231][0]").unwrap(); + let ast5 = rustpython_parser::parser::parse_expression("not True").unwrap(); + let ast6 = rustpython_parser::parser::parse_expression("[[1]][0][0]").unwrap(); + let ast7 = rustpython_parser::parser::parse_expression("[[1]][0]").unwrap(); + let ast8 = rustpython_parser::parser::parse_expression("[[(1, 2), (2, 3), (3, 4)], [(2, 4), (4, 6)]][0]").unwrap(); + let ast9 = rustpython_parser::parser::parse_expression("[1, 2, 3, 4, 5][1: 2]").unwrap(); + let ast10 = rustpython_parser::parser::parse_expression("4 if False and True else 8").unwrap(); + + let folded = inf.fold_expr(ast1).unwrap(); + let folded_2 = Premapper.fold_expr(ast2).unwrap(); + let folded_3 = inf.fold_expr(ast3).unwrap(); + let folded_4 = inf.fold_expr(ast4).unwrap(); + let folded_5 = inf.fold_expr(ast5).unwrap(); + let folded_6 = inf.fold_expr(ast6).unwrap(); + let folded_7 = inf.fold_expr(ast7).unwrap(); + let folded_8 = inf.fold_expr(ast8).unwrap(); + let folded_9 = inf.fold_expr(ast9).unwrap(); + let folded_10 = inf.fold_expr(ast10).unwrap(); + + println!("{:?}", folded.custom); + println!("{:?}", folded_2.custom); + println!("{:?}", folded_3.custom); + println!("{:?}", folded_4.custom); + println!("{:?}", folded_5.custom); + println!("{:?}", folded_6.custom); + println!("{:?}", folded_7.custom); + println!("{:?}", folded_8.custom); + println!("{:?}", folded_9.custom); + println!("{:?}", folded_10.custom); + + } } \ No newline at end of file