From 5f0490cd84e9c86a6b82fc42ff72e9ddb8f7c157 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Tue, 27 Jul 2021 11:58:35 +0800 Subject: [PATCH] added virtual test --- nac3core/src/typecheck/type_inferencer/mod.rs | 4 + .../src/typecheck/type_inferencer/test.rs | 104 +++++++++++++++--- 2 files changed, 91 insertions(+), 17 deletions(-) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 8e5aec2a0..5eecb1829 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -28,6 +28,7 @@ pub struct PrimitiveStore { pub struct Inferencer<'a> { pub resolver: &'a mut Box, pub unifier: &'a mut Unifier, + pub virtual_checks: &'a mut Vec<(Type, Type)>, pub variable_mapping: HashMap, pub calls: &'a mut Vec>, pub primitives: &'a PrimitiveStore, @@ -208,6 +209,7 @@ impl<'a> Inferencer<'a> { let mut new_context = Inferencer { resolver: self.resolver, unifier: self.unifier, + virtual_checks: self.virtual_checks, variable_mapping, calls: self.calls, primitives: self.primitives, @@ -250,6 +252,7 @@ impl<'a> Inferencer<'a> { let mut new_context = Inferencer { resolver: self.resolver, unifier: self.unifier, + virtual_checks: self.virtual_checks, variable_mapping, calls: self.calls, primitives: self.primitives, @@ -318,6 +321,7 @@ impl<'a> Inferencer<'a> { } else { self.unifier.get_fresh_var().0 }; + self.virtual_checks.push((arg0.custom.unwrap(), ty)); let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); return Ok(Located { location, diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index be9b1506f..598aedf9c 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -3,12 +3,14 @@ use super::super::symbol_resolver::*; use super::super::typedef::*; use super::*; use indoc::indoc; +use itertools::zip; use rustpython_parser::ast; use rustpython_parser::parser::parse_program; use test_case::test_case; struct Resolver { identifier_mapping: HashMap, + class_names: HashMap, } impl SymbolResolver for Resolver { @@ -16,8 +18,12 @@ impl SymbolResolver for Resolver { self.identifier_mapping.get(str).cloned() } - fn parse_type_name(&mut self, _: &ast::Expr<()>) -> Option { - unimplemented!() + fn parse_type_name(&mut self, ty: &ast::Expr<()>) -> Option { + if let ExprKind::Name { id, .. } = &ty.node { + self.class_names.get(id).cloned() + } else { + unimplemented!() + } } fn get_symbol_value(&mut self, _: &str) -> Option { @@ -36,6 +42,7 @@ struct TestEnvironment { pub primitives: PrimitiveStore, pub id_to_name: HashMap, pub identifier_mapping: HashMap, + pub virtual_checks: Vec<(Type, Type)>, } impl TestEnvironment { @@ -69,13 +76,7 @@ impl TestEnvironment { }); identifier_mapping.insert("None".into(), none); - let primitives = PrimitiveStore { - int32, - int64, - float, - bool, - none, - }; + let primitives = PrimitiveStore { int32, int64, float, bool, none }; let (v0, id) = unifier.get_fresh_var(); @@ -94,6 +95,40 @@ impl TestEnvironment { })), ); + let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: int32, + vars: Default::default(), + })); + let bar = unifier.add_ty(TypeEnum::TObj { + obj_id: 6, + fields: [("a".into(), int32), ("b".into(), fun)].iter().cloned().collect(), + params: Default::default(), + }); + identifier_mapping.insert( + "Bar".into(), + unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: bar, + vars: Default::default(), + })), + ); + + let bar2 = unifier.add_ty(TypeEnum::TObj { + obj_id: 7, + fields: [("a".into(), bool), ("b".into(), fun)].iter().cloned().collect(), + params: Default::default(), + }); + identifier_mapping.insert( + "Bar2".into(), + unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: bar2, + vars: Default::default(), + })), + ); + let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect(); + let id_to_name = [ (0, "int32".to_string()), (1, "int64".to_string()), @@ -101,12 +136,16 @@ impl TestEnvironment { (3, "bool".to_string()), (4, "none".to_string()), (5, "Foo".to_string()), + (6, "Bar".to_string()), + (7, "Bar2".to_string()), ] .iter() .cloned() .collect(); - let resolver = Box::new(Resolver { identifier_mapping: identifier_mapping.clone() }) as Box; + let resolver = + Box::new(Resolver { identifier_mapping: identifier_mapping.clone(), class_names }) + as Box; TestEnvironment { unifier, @@ -115,6 +154,7 @@ impl TestEnvironment { id_to_name, identifier_mapping, calls: Vec::new(), + virtual_checks: Vec::new(), } } @@ -125,7 +165,8 @@ impl TestEnvironment { variable_mapping: Default::default(), calls: &mut self.calls, primitives: &mut self.primitives, - return_type: None + virtual_checks: &mut self.virtual_checks, + return_type: None, } } } @@ -136,7 +177,8 @@ impl TestEnvironment { c = 1.234 d = True "}, - [("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect() + [("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(), + &[] ; "primitives test")] #[test_case(indoc! {" a = lambda x, y: x @@ -144,7 +186,8 @@ impl TestEnvironment { c = 1.234 d = b(c) "}, - [("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect() + [("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(), + &[] ; "lambda test")] #[test_case(indoc! {" a = lambda x: x @@ -160,20 +203,31 @@ impl TestEnvironment { "}, [("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"), - ("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect() + ("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(), + &[] ; "obj test")] #[test_case(indoc! {" f = lambda x: True a = [1, 2, 3] b = [f(x) for x in a if f(x)] "}, - [("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect() + [("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(), + &[] ; "listcomp test")] -fn test_basic(source: &str, mapping: HashMap<&str, &str>) { +#[test_case(indoc! {" + a = virtual(Bar(), Bar) + b = a.b() + a = virtual(Bar2()) + "}, + [("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(), + &[("Bar", "Bar"), ("Bar2", "Bar")] + ; "virtual test")] +fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) { println!("source:\n{}", source); let mut env = TestEnvironment::new(); let id_to_name = std::mem::take(&mut env.id_to_name); - let mut defined_identifiers = env.identifier_mapping.keys().cloned().collect(); + let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect(); + defined_identifiers.push("virtual".to_string()); let mut inferencer = env.get_inferencer(); let statements = parse_program(source).unwrap(); let statements = statements @@ -201,4 +255,20 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>) { ); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); } + assert_eq!(inferencer.virtual_checks.len(), virtuals.len()); + for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) { + let a = inferencer.unifier.stringify( + *a, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + let b = inferencer.unifier.stringify( + *b, + &mut |v| id_to_name.get(&v).unwrap().clone(), + &mut |v| format!("v{}", v), + ); + + assert_eq!(&a, x); + assert_eq!(&b, y); + } }