From 44f4b7cfc754d8f7c71263b08dd53e9657ad9fcd Mon Sep 17 00:00:00 2001 From: pca006132 Date: Mon, 4 Jan 2021 13:25:10 +0800 Subject: [PATCH] enabled tests and applied clippy --- nac3core/src/expression_inference.rs | 1287 +++++++++++++------------- 1 file changed, 628 insertions(+), 659 deletions(-) diff --git a/nac3core/src/expression_inference.rs b/nac3core/src/expression_inference.rs index 90a02026..8ec37303 100644 --- a/nac3core/src/expression_inference.rs +++ b/nac3core/src/expression_inference.rs @@ -1,8 +1,8 @@ +use crate::context::InferenceContext; use crate::inference_core::resolve_call; use crate::magic_methods::*; use crate::primitives::*; use crate::typedef::{Type, TypeEnum::*}; -use crate::context::InferenceContext; use rustpython_parser::ast::{ Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator, UnaryOperator, @@ -11,7 +11,10 @@ use std::convert::TryInto; type ParserResult = Result, String>; -pub fn infer_expr<'b: 'a, 'a>(ctx: &mut InferenceContext<'a>, expr: &'b Expression) -> ParserResult { +pub fn infer_expr<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + expr: &'b Expression, +) -> ParserResult { match &expr.node { ExpressionType::Number { value } => infer_constant(ctx, value), ExpressionType::Identifier { name } => infer_identifier(ctx, name), @@ -27,7 +30,7 @@ pub fn infer_expr<'b: 'a, 'a>(ctx: &mut InferenceContext<'a>, expr: &'b Expressi function, keywords, } => { - if keywords.is_empty() { + if !keywords.is_empty() { Err("keyword is not supported".into()) } else { infer_call(ctx, &args, &function) @@ -80,7 +83,10 @@ fn infer_identifier(ctx: &mut InferenceContext, name: &str) -> ParserResult { Ok(Some(ctx.resolve(name)?)) } -fn infer_list<'b: 'a, 'a>(ctx: &mut InferenceContext<'a>, elements: &'b [Expression]) -> ParserResult { +fn infer_list<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + elements: &'b [Expression], +) -> ParserResult { if elements.is_empty() { return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into())); } @@ -99,11 +105,12 @@ fn infer_list<'b: 'a, 'a>(ctx: &mut InferenceContext<'a>, elements: &'b [Express Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into())) } -fn infer_tuple<'b: 'a, 'a>(ctx: &mut InferenceContext<'a>, elements: &'b [Expression]) -> ParserResult { - let types: Result>, String> = elements - .iter() - .map(|v| infer_expr(ctx, v)) - .collect(); +fn infer_tuple<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + elements: &'b [Expression], +) -> ParserResult { + let types: Result>, String> = + elements.iter().map(|v| infer_expr(ctx, v)).collect(); if let Some(t) = types? { Ok(Some(ParametricType(TUPLE_TYPE, t).into())) } else { @@ -122,9 +129,7 @@ fn infer_attribute<'a>( if v.bound.is_empty() { return Err("no fields on unbounded type variable".into()); } - let ty = v.bound[0] - .get_base(ctx) - .and_then(|v| v.fields.get(name)); + let ty = v.bound[0].get_base(ctx).and_then(|v| v.fields.get(name)); if ty.is_none() { return Err("unknown field".into()); } @@ -146,10 +151,7 @@ fn infer_attribute<'a>( } } -fn infer_bool_ops<'a>( - ctx: &mut InferenceContext<'a>, - values: &'a [Expression], -) -> ParserResult { +fn infer_bool_ops<'a>(ctx: &mut InferenceContext<'a>, values: &'a [Expression]) -> ParserResult { assert_eq!(values.len(), 2); let left = infer_expr(ctx, &values[0])?.ok_or_else(|| "no value".to_string())?; let right = infer_expr(ctx, &values[1])?.ok_or_else(|| "no value".to_string())?; @@ -196,8 +198,7 @@ fn infer_compare<'b: 'a, 'a>( vals: &'b [Expression], ops: &'b [Comparison], ) -> ParserResult { - let types: Result>, _> = - vals.iter().map(|v| infer_expr(ctx, v)).collect(); + let types: Result>, _> = vals.iter().map(|v| infer_expr(ctx, v)).collect(); let types = types?; if types.is_none() { return Err("comparison operands must have type".into()); @@ -222,8 +223,7 @@ fn infer_call<'b: 'a, 'a>( args: &'b [Expression], function: &'b Expression, ) -> ParserResult { - let types: Result>, _> = - args.iter().map(|v| infer_expr(ctx, v)).collect(); + let types: Result>, _> = args.iter().map(|v| infer_expr(ctx, v)).collect(); let types = types?; if types.is_none() { return Err("function params must have type".into()); @@ -361,647 +361,616 @@ fn infer_list_comprehension<'b: 'a, 'a>( } let result = infer_expr(ctx, element)?.ok_or_else(|| "no value")?; Ok(Some(ParametricType(LIST_TYPE, vec![result]).into())) - }).1 + }) + .1 } else { Err("iteration is supported for list only".into()) } } -// #[cfg(test)] -// mod test { -// use super::*; -// use crate::typedef::*; -// use rustpython_parser::parser::parse_expression; - -// #[test] -// fn test_constants() { -// let ctx = basic_ctx(); -// let sym_table = HashMap::new(); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("123").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("2147483647").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("2147483648").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT64_TYPE).into()); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("9223372036854775807").unwrap(), -// ); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT64_TYPE).into()); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("9223372036854775808").unwrap(), -// ); -// assert_eq!(result, Err("integer out of range".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("123.456").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(FLOAT_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("True").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("False").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); -// } - -// #[test] -// fn test_identifier() { -// let ctx = basic_ctx(); -// let mut sym_table = HashMap::new(); -// sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("abc").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("ab").unwrap()); -// assert_eq!(result, Err("unbounded variable".into())); -// } - -// #[test] -// fn test_list() { -// let mut ctx = basic_ctx(); -// ctx.add_fn( -// "foo", -// FnDef { -// args: vec![], -// result: None, -// }, -// ); -// let mut sym_table = HashMap::new(); -// sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); -// // def is reserved... -// sym_table.insert("efg", Rc::new(PrimitiveType(INT32_TYPE))); -// sym_table.insert("xyz", Rc::new(PrimitiveType(FLOAT_TYPE))); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("[]").unwrap()); -// assert_eq!( -// result.unwrap().unwrap(), -// ParametricType(LIST_TYPE, vec![BotType.into()]).into() -// ); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("[abc]").unwrap()); -// assert_eq!( -// result.unwrap().unwrap(), -// ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() -// ); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("[abc, efg]").unwrap()); -// assert_eq!( -// result.unwrap().unwrap(), -// ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() -// ); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[abc, efg, xyz]").unwrap(), -// ); -// assert_eq!(result, Err("inhomogeneous list is not allowed".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("[foo()]").unwrap()); -// assert_eq!(result, Err("list elements must have some type".into())); -// } - -// #[test] -// fn test_tuple() { -// let mut ctx = basic_ctx(); -// ctx.add_fn( -// "foo", -// FnDef { -// args: vec![], -// result: None, -// }, -// ); -// let mut sym_table = HashMap::new(); -// sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); -// sym_table.insert("efg", Rc::new(PrimitiveType(FLOAT_TYPE))); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("(abc, efg)").unwrap()); -// assert_eq!( -// result.unwrap().unwrap(), -// ParametricType( -// TUPLE_TYPE, -// vec![ -// PrimitiveType(INT32_TYPE).into(), -// PrimitiveType(FLOAT_TYPE).into() -// ] -// ) -// .into() -// ); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("(abc, efg, foo())").unwrap(), -// ); -// assert_eq!(result, Err("tuple elements must have some type".into())); -// } - -// #[test] -// fn test_attribute() { -// let mut ctx = basic_ctx(); -// ctx.add_fn( -// "none", -// FnDef { -// args: vec![], -// result: None, -// }, -// ); - -// let foo = ctx.add_class(ClassDef { -// base: TypeDef { -// name: "Foo", -// fields: HashMap::new(), -// methods: HashMap::new(), -// }, -// parents: vec![], -// }); -// let foo_def = ctx.get_class_mut(foo); -// foo_def -// .base -// .fields -// .insert("a", PrimitiveType(INT32_TYPE).into()); -// foo_def.base.fields.insert("b", ClassType(foo).into()); -// foo_def -// .base -// .fields -// .insert("c", PrimitiveType(INT32_TYPE).into()); - -// let bar = ctx.add_class(ClassDef { -// base: TypeDef { -// name: "Bar", -// fields: HashMap::new(), -// methods: HashMap::new(), -// }, -// parents: vec![], -// }); -// let bar_def = ctx.get_class_mut(bar); -// bar_def -// .base -// .fields -// .insert("a", PrimitiveType(INT32_TYPE).into()); -// bar_def.base.fields.insert("b", ClassType(bar).into()); -// bar_def -// .base -// .fields -// .insert("c", PrimitiveType(FLOAT_TYPE).into()); - -// let v0 = ctx.add_variable(VarDef { -// name: "v0", -// bound: vec![], -// }); - -// let v1 = ctx.add_variable(VarDef { -// name: "v1", -// bound: vec![ClassType(foo).into(), ClassType(bar).into()], -// }); - -// let mut sym_table = HashMap::new(); -// sym_table.insert("foo", Rc::new(ClassType(foo))); -// sym_table.insert("bar", Rc::new(ClassType(bar))); -// sym_table.insert("foobar", Rc::new(VirtualClassType(foo))); -// sym_table.insert("v0", Rc::new(TypeVariable(v0))); -// sym_table.insert("v1", Rc::new(TypeVariable(v1))); -// sym_table.insert("bot", Rc::new(BotType)); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("foo.a").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("foo.d").unwrap()); -// assert_eq!(result, Err("no such field".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("foobar.a").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("v0.a").unwrap()); -// assert_eq!(result, Err("no fields on unbounded type variable".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.a").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// // shall we support this? -// let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.b").unwrap()); -// assert_eq!( -// result, -// Err("unknown field (type mismatch between variants)".into()) -// ); -// // assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.c").unwrap()); -// assert_eq!( -// result, -// Err("unknown field (type mismatch between variants)".into()) -// ); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.d").unwrap()); -// assert_eq!(result, Err("unknown field".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("none().a").unwrap()); -// assert_eq!(result, Err("no value".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("bot.a").unwrap()); -// assert_eq!(result, Err("this object has no fields".into())); -// } - -// #[test] -// fn test_bool_ops() { -// let mut ctx = basic_ctx(); -// ctx.add_fn( -// "none", -// FnDef { -// args: vec![], -// result: None, -// }, -// ); -// let sym_table = HashMap::new(); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("True and False").unwrap(), -// ); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("True and none()").unwrap(), -// ); -// assert_eq!(result, Err("no value".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("True and 123").unwrap()); -// assert_eq!(result, Err("bool operands must be bool".into())); -// } - -// #[test] -// fn test_bin_ops() { -// let mut ctx = basic_ctx(); -// let v0 = ctx.add_variable(VarDef { -// name: "v0", -// bound: vec![ -// PrimitiveType(INT32_TYPE).into(), -// PrimitiveType(INT64_TYPE).into(), -// ], -// }); -// let mut sym_table = HashMap::new(); -// sym_table.insert("a", TypeVariable(v0).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("1 + 2 + 3").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("a + a + a").unwrap()); -// assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into()); -// } - -// #[test] -// fn test_unary_ops() { -// let mut ctx = basic_ctx(); -// let v0 = ctx.add_variable(VarDef { -// name: "v0", -// bound: vec![ -// PrimitiveType(INT32_TYPE).into(), -// PrimitiveType(INT64_TYPE).into(), -// ], -// }); -// let mut sym_table = HashMap::new(); -// sym_table.insert("a", TypeVariable(v0).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("-(123)").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("-a").unwrap()); -// assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("not True").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("not (1)").unwrap()); -// assert_eq!(result, Err("logical not must be applied to bool".into())); -// } - -// #[test] -// fn test_compare() { -// let mut ctx = basic_ctx(); -// let v0 = ctx.add_variable(VarDef { -// name: "v0", -// bound: vec![ -// PrimitiveType(INT32_TYPE).into(), -// PrimitiveType(INT64_TYPE).into(), -// ], -// }); -// let mut sym_table = HashMap::new(); -// sym_table.insert("a", TypeVariable(v0).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("a == a == a").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("a == a == 1").unwrap()); -// assert_eq!(result, Err("not equal".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("True > False").unwrap()); -// assert_eq!(result, Err("no such function".into())); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("True in False").unwrap(), -// ); -// assert_eq!(result, Err("unsupported comparison".into())); -// } - -// #[test] -// fn test_call() { -// let mut ctx = basic_ctx(); -// ctx.add_fn( -// "none", -// FnDef { -// args: vec![], -// result: None, -// }, -// ); - -// let foo = ctx.add_class(ClassDef { -// base: TypeDef { -// name: "Foo", -// fields: HashMap::new(), -// methods: HashMap::new(), -// }, -// parents: vec![], -// }); -// let foo_def = ctx.get_class_mut(foo); -// foo_def.base.methods.insert( -// "a", -// FnDef { -// args: vec![], -// result: Some(Rc::new(ClassType(foo))), -// }, -// ); - -// let bar = ctx.add_class(ClassDef { -// base: TypeDef { -// name: "Bar", -// fields: HashMap::new(), -// methods: HashMap::new(), -// }, -// parents: vec![], -// }); -// let bar_def = ctx.get_class_mut(bar); -// bar_def.base.methods.insert( -// "a", -// FnDef { -// args: vec![], -// result: Some(Rc::new(ClassType(bar))), -// }, -// ); - -// let v0 = ctx.add_variable(VarDef { -// name: "v0", -// bound: vec![], -// }); -// let v1 = ctx.add_variable(VarDef { -// name: "v1", -// bound: vec![ClassType(foo).into(), ClassType(bar).into()], -// }); -// let v2 = ctx.add_variable(VarDef { -// name: "v2", -// bound: vec![ -// ClassType(foo).into(), -// ClassType(bar).into(), -// PrimitiveType(INT32_TYPE).into(), -// ], -// }); -// let mut sym_table = HashMap::new(); -// sym_table.insert("foo", Rc::new(ClassType(foo))); -// sym_table.insert("bar", Rc::new(ClassType(bar))); -// sym_table.insert("foobar", Rc::new(VirtualClassType(foo))); -// sym_table.insert("v0", Rc::new(TypeVariable(v0))); -// sym_table.insert("v1", Rc::new(TypeVariable(v1))); -// sym_table.insert("v2", Rc::new(TypeVariable(v2))); -// sym_table.insert("bot", Rc::new(BotType)); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("foo.a()").unwrap()); -// assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.a()").unwrap()); -// assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("foobar.a()").unwrap()); -// assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("none().a()").unwrap()); -// assert_eq!(result, Err("no value".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("bot.a()").unwrap()); -// assert_eq!(result, Err("not supported".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("[][0].a()").unwrap()); -// assert_eq!(result, Err("not supported".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("v0.a()").unwrap()); -// assert_eq!(result, Err("unbounded type var".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("v2.a()").unwrap()); -// assert_eq!(result, Err("no such function".into())); -// } - -// #[test] -// fn infer_subscript() { -// let mut ctx = basic_ctx(); -// ctx.add_fn( -// "none", -// FnDef { -// args: vec![], -// result: None, -// }, -// ); -// let sym_table = HashMap::new(); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("[1, 2, 3][0]").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("[[1]][0][0]").unwrap()); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[1, 2, 3][1:2]").unwrap(), -// ); -// assert_eq!( -// result.unwrap().unwrap(), -// ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() -// ); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[1, 2, 3][1:2:2]").unwrap(), -// ); -// assert_eq!( -// result.unwrap().unwrap(), -// ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() -// ); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[1, 2, 3][1:1.2]").unwrap(), -// ); -// assert_eq!(result, Err("slice must be int32 type".into())); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[1, 2, 3][1:none()]").unwrap(), -// ); -// assert_eq!(result, Err("slice must have type".into())); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[1, 2, 3][1.2]").unwrap(), -// ); -// assert_eq!(result, Err("index must be either slice or int32".into())); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[1, 2, 3][none()]").unwrap(), -// ); -// assert_eq!(result, Err("no value".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("none()[1.2]").unwrap()); -// assert_eq!(result, Err("no value".into())); - -// let result = infer_expr(&ctx, &sym_table, &parse_expression("123[1]").unwrap()); -// assert_eq!( -// result, -// Err("subscript is not supported for types other than list".into()) -// ); -// } - -// #[test] -// fn test_if_expr() { -// let mut ctx = basic_ctx(); -// ctx.add_fn( -// "none", -// FnDef { -// args: vec![], -// result: None, -// }, -// ); -// let sym_table = HashMap::new(); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("1 if True else 0").unwrap(), -// ); -// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("none() if True else none()").unwrap(), -// ); -// assert_eq!(result.unwrap(), None); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("none() if 1 else none()").unwrap(), -// ); -// assert_eq!(result, Err("test should be bool".into())); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("1 if True else none()").unwrap(), -// ); -// assert_eq!(result, Err("divergent type".into())); -// } - -// #[test] -// fn test_list_comp() { -// let mut ctx = basic_ctx(); -// ctx.add_fn( -// "none", -// FnDef { -// args: vec![], -// result: None, -// }, -// ); -// let int32 = Rc::new(PrimitiveType(INT32_TYPE)); -// let mut sym_table = HashMap::new(); -// sym_table.insert("z", int32.clone()); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(), -// ); -// assert_eq!( -// result.unwrap().unwrap(), -// ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into() -// ); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(), -// ); -// assert_eq!(result.unwrap().unwrap(), int32.clone()); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap(), -// ); -// assert_eq!(result.unwrap().unwrap(), int32.clone()); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap(), -// ); -// assert_eq!(result, Err("test must be bool".into())); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[y for x in []][0]").unwrap(), -// ); -// assert_eq!(result, Err("unbounded variable".into())); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[none() for x in []][0]").unwrap(), -// ); -// assert_eq!(result, Err("no value".into())); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[z for z in []][0]").unwrap(), -// ); -// assert_eq!(result, Err("duplicated naming".into())); - -// let result = infer_expr( -// &ctx, -// &sym_table, -// &parse_expression("[x for x in [] for y in []]").unwrap(), -// ); -// assert_eq!( -// result, -// Err("only 1 generator statement is supported".into()) -// ); -// } -// } +#[cfg(test)] +mod test { + use super::*; + use crate::context::*; + use crate::typedef::*; + use rustpython_parser::parser::parse_expression; + use std::collections::HashMap; + use std::rc::Rc; + + fn get_inference_context(ctx: TopLevelContext) -> InferenceContext { + InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into()))) + } + + #[test] + fn test_constants() { + let ctx = basic_ctx(); + let mut ctx = get_inference_context(ctx); + + let ast = parse_expression("123").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + let ast = parse_expression("2147483647").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + let ast = parse_expression("2147483648").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE)); + + let ast = parse_expression("9223372036854775807").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT64_TYPE)); + + let ast = parse_expression("9223372036854775808").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("integer out of range".into())); + + let ast = parse_expression("123.456").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(FLOAT_TYPE)); + + let ast = parse_expression("True").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE)); + + let ast = parse_expression("False").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE)); + } + + #[test] + fn test_identifier() { + let ctx = basic_ctx(); + let mut ctx = get_inference_context(ctx); + ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap(); + + let ast = parse_expression("abc").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + let ast = parse_expression("ab").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("unbounded identifier".into())); + } + + #[test] + fn test_list() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "foo", + FnDef { + args: vec![], + result: None, + }, + ); + let mut ctx = get_inference_context(ctx); + ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap(); + // def is reserved... + ctx.assign("efg", ctx.get_primitive(INT32_TYPE)).unwrap(); + ctx.assign("xyz", ctx.get_primitive(FLOAT_TYPE)).unwrap(); + + let ast = parse_expression("[]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(LIST_TYPE, vec![BotType.into()]).into() + ); + + let ast = parse_expression("[abc]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into() + ); + + let ast = parse_expression("[abc, efg]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into() + ); + + let ast = parse_expression("[abc, efg, xyz]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("inhomogeneous list is not allowed".into())); + + let ast = parse_expression("[foo()]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("list elements must have some type".into())); + } + + #[test] + fn test_tuple() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "foo", + FnDef { + args: vec![], + result: None, + }, + ); + let mut ctx = get_inference_context(ctx); + ctx.assign("abc", ctx.get_primitive(INT32_TYPE)).unwrap(); + ctx.assign("efg", ctx.get_primitive(FLOAT_TYPE)).unwrap(); + + let ast = parse_expression("(abc, efg)").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result.unwrap().unwrap(), + ParametricType( + TUPLE_TYPE, + vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)] + ) + .into() + ); + + let ast = parse_expression("(abc, efg, foo())").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("tuple elements must have some type".into())); + } + + #[test] + fn test_attribute() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + let int32 = ctx.get_primitive(INT32_TYPE); + let float = ctx.get_primitive(FLOAT_TYPE); + + let foo = ctx.add_class(ClassDef { + base: TypeDef { + name: "Foo", + fields: HashMap::new(), + methods: HashMap::new(), + }, + parents: vec![], + }); + let foo_def = ctx.get_class_def_mut(foo); + foo_def + .base + .fields + .insert("a", int32.clone()); + foo_def.base.fields.insert("b", ClassType(foo).into()); + foo_def + .base + .fields + .insert("c", int32.clone()); + + let bar = ctx.add_class(ClassDef { + base: TypeDef { + name: "Bar", + fields: HashMap::new(), + methods: HashMap::new(), + }, + parents: vec![], + }); + let bar_def = ctx.get_class_def_mut(bar); + bar_def + .base + .fields + .insert("a", int32); + bar_def.base.fields.insert("b", ClassType(bar).into()); + bar_def + .base + .fields + .insert("c", float); + + let v0 = ctx.add_variable(VarDef { + name: "v0", + bound: vec![], + }); + + let v1 = ctx.add_variable(VarDef { + name: "v1", + bound: vec![ClassType(foo).into(), ClassType(bar).into()], + }); + + let mut ctx = get_inference_context(ctx); + ctx.assign("foo", Rc::new(ClassType(foo))).unwrap(); + ctx.assign("bar", Rc::new(ClassType(bar))).unwrap(); + ctx.assign("foobar", Rc::new(VirtualClassType(foo))).unwrap(); + ctx.assign("v0", ctx.get_variable(v0)).unwrap(); + ctx.assign("v1", ctx.get_variable(v1)).unwrap(); + ctx.assign("bot", Rc::new(BotType)).unwrap(); + + let ast = parse_expression("foo.a").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + let ast = parse_expression("foo.d").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("no such field".into())); + + let ast = parse_expression("foobar.a").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + let ast = parse_expression("v0.a").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("no fields on unbounded type variable".into())); + + let ast = parse_expression("v1.a").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + // shall we support this? + let ast = parse_expression("v1.b").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result, + Err("unknown field (type mismatch between variants)".into()) + ); + // assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into()); + + let ast = parse_expression("v1.c").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result, + Err("unknown field (type mismatch between variants)".into()) + ); + + let ast = parse_expression("v1.d").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("unknown field".into())); + + let ast = parse_expression("none().a").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("no value".into())); + + let ast = parse_expression("bot.a").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("this object has no fields".into())); + } + + #[test] + fn test_bool_ops() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + let mut ctx = get_inference_context(ctx); + + let ast = parse_expression("True and False").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE)); + + let ast = parse_expression("True and none()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("no value".into())); + + let ast = parse_expression("True and 123").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("bool operands must be bool".into())); + } + + #[test] + fn test_bin_ops() { + let mut ctx = basic_ctx(); + let v0 = ctx.add_variable(VarDef { + name: "v0", + bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)], + }); + let mut ctx = get_inference_context(ctx); + ctx.assign("a", TypeVariable(v0).into()).unwrap(); + + let ast = parse_expression("1 + 2 + 3").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + let ast = parse_expression("a + a + a").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into()); + } + + #[test] + fn test_unary_ops() { + let mut ctx = basic_ctx(); + let v0 = ctx.add_variable(VarDef { + name: "v0", + bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)], + }); + let mut ctx = get_inference_context(ctx); + ctx.assign("a", TypeVariable(v0).into()).unwrap(); + + let ast = parse_expression("-(123)").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + let ast = parse_expression("-a").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into()); + + let ast = parse_expression("not True").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE)); + + let ast = parse_expression("not (1)").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("logical not must be applied to bool".into())); + } + + #[test] + fn test_compare() { + let mut ctx = basic_ctx(); + let v0 = ctx.add_variable(VarDef { + name: "v0", + bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(INT64_TYPE)], + }); + let mut ctx = get_inference_context(ctx); + ctx.assign("a", TypeVariable(v0).into()).unwrap(); + + let ast = parse_expression("a == a == a").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(BOOL_TYPE)); + + let ast = parse_expression("a == a == 1").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("not equal".into())); + + let ast = parse_expression("True > False").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("no such function".into())); + + let ast = parse_expression("True in False").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("unsupported comparison".into())); + } + + #[test] + fn test_call() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + + let foo = ctx.add_class(ClassDef { + base: TypeDef { + name: "Foo", + fields: HashMap::new(), + methods: HashMap::new(), + }, + parents: vec![], + }); + let foo_def = ctx.get_class_def_mut(foo); + foo_def.base.methods.insert( + "a", + FnDef { + args: vec![], + result: Some(Rc::new(ClassType(foo))), + }, + ); + + let bar = ctx.add_class(ClassDef { + base: TypeDef { + name: "Bar", + fields: HashMap::new(), + methods: HashMap::new(), + }, + parents: vec![], + }); + let bar_def = ctx.get_class_def_mut(bar); + bar_def.base.methods.insert( + "a", + FnDef { + args: vec![], + result: Some(Rc::new(ClassType(bar))), + }, + ); + + let v0 = ctx.add_variable(VarDef { + name: "v0", + bound: vec![], + }); + let v1 = ctx.add_variable(VarDef { + name: "v1", + bound: vec![ClassType(foo).into(), ClassType(bar).into()], + }); + let v2 = ctx.add_variable(VarDef { + name: "v2", + bound: vec![ + ClassType(foo).into(), + ClassType(bar).into(), + ctx.get_primitive(INT32_TYPE), + ], + }); + let mut ctx = get_inference_context(ctx); + ctx.assign("foo", Rc::new(ClassType(foo))).unwrap(); + ctx.assign("bar", Rc::new(ClassType(bar))).unwrap(); + ctx.assign("foobar", Rc::new(VirtualClassType(foo))).unwrap(); + ctx.assign("v0", ctx.get_variable(v0)).unwrap(); + ctx.assign("v1", ctx.get_variable(v1)).unwrap(); + ctx.assign("v2", ctx.get_variable(v2)).unwrap(); + ctx.assign("bot", Rc::new(BotType)).unwrap(); + + let ast = parse_expression("foo.a()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); + + let ast = parse_expression("v1.a()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into()); + + let ast = parse_expression("foobar.a()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); + + let ast = parse_expression("none().a()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("no value".into())); + + let ast = parse_expression("bot.a()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("not supported".into())); + + let ast = parse_expression("[][0].a()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("not supported".into())); + + let ast = parse_expression("v0.a()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("unbounded type var".into())); + + let ast = parse_expression("v2.a()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("no such function".into())); + } + + #[test] + fn infer_subscript() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + let mut ctx = get_inference_context(ctx); + + let ast = parse_expression("[1, 2, 3][0]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + let ast = parse_expression("[[1]][0][0]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + let ast = parse_expression("[1, 2, 3][1:2]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into() + ); + + let ast = parse_expression("[1, 2, 3][1:2:2]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(LIST_TYPE, vec![ctx.get_primitive(INT32_TYPE)]).into() + ); + + let ast = parse_expression("[1, 2, 3][1:1.2]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("slice must be int32 type".into())); + + let ast = parse_expression("[1, 2, 3][1:none()]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("slice must have type".into())); + + let ast = parse_expression("[1, 2, 3][1.2]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("index must be either slice or int32".into())); + + let ast = parse_expression("[1, 2, 3][none()]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("no value".into())); + + let ast = parse_expression("none()[1.2]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("no value".into())); + + let ast = parse_expression("123[1]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result, + Err("subscript is not supported for types other than list".into()) + ); + } + + #[test] + fn test_if_expr() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + let mut ctx = get_inference_context(ctx); + + let ast = parse_expression("1 if True else 0").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), ctx.get_primitive(INT32_TYPE)); + + let ast = parse_expression("none() if True else none()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap(), None); + + let ast = parse_expression("none() if 1 else none()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("test should be bool".into())); + + let ast = parse_expression("1 if True else none()").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("divergent type".into())); + } + + #[test] + fn test_list_comp() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + let int32 = ctx.get_primitive(INT32_TYPE); + let mut ctx = get_inference_context(ctx); + ctx.assign("z", int32.clone()).unwrap(); + + let ast = parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into() + ); + + let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), int32); + + let ast = + parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result.unwrap().unwrap(), int32); + + let ast = parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("test must be bool".into())); + + let ast = parse_expression("[y for x in []][0]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("unbounded identifier".into())); + + let ast = parse_expression("[none() for x in []][0]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("no value".into())); + + let ast = parse_expression("[z for z in []][0]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!(result, Err("duplicated naming".into())); + + let ast = parse_expression("[x for x in [] for y in []]").unwrap(); + let result = infer_expr(&mut ctx, &ast); + assert_eq!( + result, + Err("only 1 generator statement is supported".into()) + ); + } +}