added virtual test

This commit is contained in:
pca006132 2021-07-27 11:58:35 +08:00
parent 1d13b16f94
commit 5f0490cd84
2 changed files with 91 additions and 17 deletions

View File

@ -28,6 +28,7 @@ pub struct PrimitiveStore {
pub struct Inferencer<'a> { pub struct Inferencer<'a> {
pub resolver: &'a mut Box<dyn SymbolResolver>, pub resolver: &'a mut Box<dyn SymbolResolver>,
pub unifier: &'a mut Unifier, pub unifier: &'a mut Unifier,
pub virtual_checks: &'a mut Vec<(Type, Type)>,
pub variable_mapping: HashMap<String, Type>, pub variable_mapping: HashMap<String, Type>,
pub calls: &'a mut Vec<Rc<Call>>, pub calls: &'a mut Vec<Rc<Call>>,
pub primitives: &'a PrimitiveStore, pub primitives: &'a PrimitiveStore,
@ -208,6 +209,7 @@ impl<'a> Inferencer<'a> {
let mut new_context = Inferencer { let mut new_context = Inferencer {
resolver: self.resolver, resolver: self.resolver,
unifier: self.unifier, unifier: self.unifier,
virtual_checks: self.virtual_checks,
variable_mapping, variable_mapping,
calls: self.calls, calls: self.calls,
primitives: self.primitives, primitives: self.primitives,
@ -250,6 +252,7 @@ impl<'a> Inferencer<'a> {
let mut new_context = Inferencer { let mut new_context = Inferencer {
resolver: self.resolver, resolver: self.resolver,
unifier: self.unifier, unifier: self.unifier,
virtual_checks: self.virtual_checks,
variable_mapping, variable_mapping,
calls: self.calls, calls: self.calls,
primitives: self.primitives, primitives: self.primitives,
@ -318,6 +321,7 @@ impl<'a> Inferencer<'a> {
} else { } else {
self.unifier.get_fresh_var().0 self.unifier.get_fresh_var().0
}; };
self.virtual_checks.push((arg0.custom.unwrap(), ty));
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
return Ok(Located { return Ok(Located {
location, location,

View File

@ -3,12 +3,14 @@ use super::super::symbol_resolver::*;
use super::super::typedef::*; use super::super::typedef::*;
use super::*; use super::*;
use indoc::indoc; use indoc::indoc;
use itertools::zip;
use rustpython_parser::ast; use rustpython_parser::ast;
use rustpython_parser::parser::parse_program; use rustpython_parser::parser::parse_program;
use test_case::test_case; use test_case::test_case;
struct Resolver { struct Resolver {
identifier_mapping: HashMap<String, Type>, identifier_mapping: HashMap<String, Type>,
class_names: HashMap<String, Type>,
} }
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
@ -16,8 +18,12 @@ impl SymbolResolver for Resolver {
self.identifier_mapping.get(str).cloned() self.identifier_mapping.get(str).cloned()
} }
fn parse_type_name(&mut self, _: &ast::Expr<()>) -> Option<Type> { fn parse_type_name(&mut self, ty: &ast::Expr<()>) -> Option<Type> {
unimplemented!() if let ExprKind::Name { id, .. } = &ty.node {
self.class_names.get(id).cloned()
} else {
unimplemented!()
}
} }
fn get_symbol_value(&mut self, _: &str) -> Option<SymbolValue> { fn get_symbol_value(&mut self, _: &str) -> Option<SymbolValue> {
@ -36,6 +42,7 @@ struct TestEnvironment {
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>, pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>, pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>,
} }
impl TestEnvironment { impl TestEnvironment {
@ -69,13 +76,7 @@ impl TestEnvironment {
}); });
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
let primitives = PrimitiveStore { let primitives = PrimitiveStore { int32, int64, float, bool, none };
int32,
int64,
float,
bool,
none,
};
let (v0, id) = unifier.get_fresh_var(); let (v0, id) = unifier.get_fresh_var();
@ -94,6 +95,40 @@ impl TestEnvironment {
})), })),
); );
let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: int32,
vars: Default::default(),
}));
let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: 6,
fields: [("a".into(), int32), ("b".into(), fun)].iter().cloned().collect(),
params: Default::default(),
});
identifier_mapping.insert(
"Bar".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: bar,
vars: Default::default(),
})),
);
let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: 7,
fields: [("a".into(), bool), ("b".into(), fun)].iter().cloned().collect(),
params: Default::default(),
});
identifier_mapping.insert(
"Bar2".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: bar2,
vars: Default::default(),
})),
);
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();
let id_to_name = [ let id_to_name = [
(0, "int32".to_string()), (0, "int32".to_string()),
(1, "int64".to_string()), (1, "int64".to_string()),
@ -101,12 +136,16 @@ impl TestEnvironment {
(3, "bool".to_string()), (3, "bool".to_string()),
(4, "none".to_string()), (4, "none".to_string()),
(5, "Foo".to_string()), (5, "Foo".to_string()),
(6, "Bar".to_string()),
(7, "Bar2".to_string()),
] ]
.iter() .iter()
.cloned() .cloned()
.collect(); .collect();
let resolver = Box::new(Resolver { identifier_mapping: identifier_mapping.clone() }) as Box<dyn SymbolResolver>; let resolver =
Box::new(Resolver { identifier_mapping: identifier_mapping.clone(), class_names })
as Box<dyn SymbolResolver>;
TestEnvironment { TestEnvironment {
unifier, unifier,
@ -115,6 +154,7 @@ impl TestEnvironment {
id_to_name, id_to_name,
identifier_mapping, identifier_mapping,
calls: Vec::new(), calls: Vec::new(),
virtual_checks: Vec::new(),
} }
} }
@ -125,7 +165,8 @@ impl TestEnvironment {
variable_mapping: Default::default(), variable_mapping: Default::default(),
calls: &mut self.calls, calls: &mut self.calls,
primitives: &mut self.primitives, primitives: &mut self.primitives,
return_type: None virtual_checks: &mut self.virtual_checks,
return_type: None,
} }
} }
} }
@ -136,7 +177,8 @@ impl TestEnvironment {
c = 1.234 c = 1.234
d = True d = True
"}, "},
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect() [("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(),
&[]
; "primitives test")] ; "primitives test")]
#[test_case(indoc! {" #[test_case(indoc! {"
a = lambda x, y: x a = lambda x, y: x
@ -144,7 +186,8 @@ impl TestEnvironment {
c = 1.234 c = 1.234
d = b(c) d = b(c)
"}, "},
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect() [("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(),
&[]
; "lambda test")] ; "lambda test")]
#[test_case(indoc! {" #[test_case(indoc! {"
a = lambda x: x a = lambda x: x
@ -160,20 +203,31 @@ impl TestEnvironment {
"}, "},
[("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"), [("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"),
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect() ("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(),
&[]
; "obj test")] ; "obj test")]
#[test_case(indoc! {" #[test_case(indoc! {"
f = lambda x: True f = lambda x: True
a = [1, 2, 3] a = [1, 2, 3]
b = [f(x) for x in a if f(x)] b = [f(x) for x in a if f(x)]
"}, "},
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect() [("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(),
&[]
; "listcomp test")] ; "listcomp test")]
fn test_basic(source: &str, mapping: HashMap<&str, &str>) { #[test_case(indoc! {"
a = virtual(Bar(), Bar)
b = a.b()
a = virtual(Bar2())
"},
[("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual test")]
fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
println!("source:\n{}", source); println!("source:\n{}", source);
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers = env.identifier_mapping.keys().cloned().collect(); let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
let statements = parse_program(source).unwrap(); let statements = parse_program(source).unwrap();
let statements = statements let statements = statements
@ -201,4 +255,20 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>) {
); );
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
} }
assert_eq!(inferencer.virtual_checks.len(), virtuals.len());
for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
let a = inferencer.unifier.stringify(
*a,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
let b = inferencer.unifier.stringify(
*b,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(&a, x);
assert_eq!(&b, y);
}
} }