nac3/nac3core/src/typecheck/type_inferencer/test.rs

282 lines
8.6 KiB
Rust
Raw Normal View History

2021-07-21 15:36:35 +08:00
use super::super::location::Location;
use super::super::symbol_resolver::*;
use super::super::typedef::*;
use super::*;
use indoc::indoc;
2021-07-27 11:58:35 +08:00
use itertools::zip;
2021-07-21 15:36:35 +08:00
use rustpython_parser::ast;
use rustpython_parser::parser::parse_program;
use test_case::test_case;
struct Resolver {
2021-07-22 17:07:49 +08:00
identifier_mapping: HashMap<String, Type>,
2021-07-27 11:58:35 +08:00
class_names: HashMap<String, Type>,
2021-07-21 15:36:35 +08:00
}
impl SymbolResolver for Resolver {
fn get_symbol_type(&mut self, str: &str) -> Option<Type> {
2021-07-22 17:07:49 +08:00
self.identifier_mapping.get(str).cloned()
2021-07-21 15:36:35 +08:00
}
2021-07-27 11:58:35 +08:00
fn parse_type_name(&mut self, ty: &ast::Expr<()>) -> Option<Type> {
if let ExprKind::Name { id, .. } = &ty.node {
self.class_names.get(id).cloned()
} else {
unimplemented!()
}
2021-07-21 15:36:35 +08:00
}
fn get_symbol_value(&mut self, _: &str) -> Option<SymbolValue> {
unimplemented!()
}
fn get_symbol_location(&mut self, _: &str) -> Option<Location> {
unimplemented!()
}
}
struct TestEnvironment {
pub unifier: Unifier,
pub resolver: Box<dyn SymbolResolver>,
pub calls: Vec<Rc<Call>>,
pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>,
2021-07-22 17:07:49 +08:00
pub identifier_mapping: HashMap<String, Type>,
2021-07-27 11:58:35 +08:00
pub virtual_checks: Vec<(Type, Type)>,
2021-07-21 15:36:35 +08:00
}
impl TestEnvironment {
fn new() -> TestEnvironment {
let mut unifier = Unifier::new();
2021-07-22 17:07:49 +08:00
let mut identifier_mapping = HashMap::new();
2021-07-21 15:36:35 +08:00
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: 0,
fields: HashMap::new(),
params: HashMap::new(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: 1,
fields: HashMap::new(),
params: HashMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: 2,
fields: HashMap::new(),
params: HashMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: 3,
fields: HashMap::new(),
params: HashMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: 4,
fields: HashMap::new(),
params: HashMap::new(),
});
2021-07-22 17:07:49 +08:00
identifier_mapping.insert("None".into(), none);
2021-07-21 15:36:35 +08:00
2021-07-27 11:58:35 +08:00
let primitives = PrimitiveStore { int32, int64, float, bool, none };
2021-07-21 15:36:35 +08:00
let (v0, id) = unifier.get_fresh_var();
2021-07-21 15:59:01 +08:00
let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: 5,
fields: [("a".into(), v0)].iter().cloned().collect(),
params: [(id, v0)].iter().cloned().collect(),
});
2021-07-22 17:07:49 +08:00
identifier_mapping.insert(
2021-07-21 15:59:01 +08:00
"Foo".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: foo_ty,
vars: [(id, v0)].iter().cloned().collect(),
})),
2021-07-21 15:36:35 +08:00
);
2021-07-27 11:58:35 +08:00
let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: int32,
vars: Default::default(),
}));
let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: 6,
fields: [("a".into(), int32), ("b".into(), fun)].iter().cloned().collect(),
params: Default::default(),
});
identifier_mapping.insert(
"Bar".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: bar,
vars: Default::default(),
})),
);
let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: 7,
fields: [("a".into(), bool), ("b".into(), fun)].iter().cloned().collect(),
params: Default::default(),
});
identifier_mapping.insert(
"Bar2".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: bar2,
vars: Default::default(),
})),
);
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();
2021-07-21 15:36:35 +08:00
let id_to_name = [
(0, "int32".to_string()),
(1, "int64".to_string()),
(2, "float".to_string()),
(3, "bool".to_string()),
(4, "none".to_string()),
(5, "Foo".to_string()),
2021-07-27 11:58:35 +08:00
(6, "Bar".to_string()),
(7, "Bar2".to_string()),
2021-07-21 15:36:35 +08:00
]
.iter()
.cloned()
.collect();
2021-07-27 11:58:35 +08:00
let resolver =
Box::new(Resolver { identifier_mapping: identifier_mapping.clone(), class_names })
as Box<dyn SymbolResolver>;
2021-07-21 15:36:35 +08:00
TestEnvironment {
unifier,
resolver,
primitives,
id_to_name,
2021-07-22 17:07:49 +08:00
identifier_mapping,
2021-07-21 15:36:35 +08:00
calls: Vec::new(),
2021-07-27 11:58:35 +08:00
virtual_checks: Vec::new(),
2021-07-21 15:36:35 +08:00
}
}
fn get_inferencer(&mut self) -> Inferencer {
Inferencer {
resolver: &mut self.resolver,
unifier: &mut self.unifier,
variable_mapping: Default::default(),
calls: &mut self.calls,
primitives: &mut self.primitives,
2021-07-27 11:58:35 +08:00
virtual_checks: &mut self.virtual_checks,
return_type: None,
2021-07-21 15:36:35 +08:00
}
}
}
#[test_case(indoc! {"
a = 1234
b = int64(2147483648)
c = 1.234
d = True
"},
2021-07-27 11:58:35 +08:00
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(),
&[]
2021-07-21 15:36:35 +08:00
; "primitives test")]
#[test_case(indoc! {"
a = lambda x, y: x
b = lambda x: a(x, x)
c = 1.234
d = b(c)
"},
2021-07-27 11:58:35 +08:00
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(),
&[]
2021-07-21 15:36:35 +08:00
; "lambda test")]
2021-07-21 15:59:01 +08:00
#[test_case(indoc! {"
a = lambda x: x
b = lambda x: x
foo1 = Foo()
foo2 = Foo()
c = a(foo1.a)
d = b(foo2.a)
a(True)
b(123)
"},
[("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"),
2021-07-27 11:58:35 +08:00
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(),
&[]
2021-07-21 15:59:01 +08:00
; "obj test")]
2021-07-21 16:06:06 +08:00
#[test_case(indoc! {"
f = lambda x: True
a = [1, 2, 3]
2021-07-21 16:10:11 +08:00
b = [f(x) for x in a if f(x)]
2021-07-21 16:06:06 +08:00
"},
2021-07-27 11:58:35 +08:00
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect(),
&[]
2021-07-21 16:06:06 +08:00
; "listcomp test")]
2021-07-27 11:58:35 +08:00
#[test_case(indoc! {"
a = virtual(Bar(), Bar)
b = a.b()
a = virtual(Bar2())
"},
[("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual test")]
2021-07-27 14:39:53 +08:00
#[test_case(indoc! {"
a = [virtual(Bar(), Bar), virtual(Bar2())]
b = [x.b() for x in a]
"},
[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual list test")]
2021-07-27 11:58:35 +08:00
fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
2021-07-22 17:07:49 +08:00
println!("source:\n{}", source);
2021-07-21 15:36:35 +08:00
let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name);
2021-07-27 11:58:35 +08:00
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string());
2021-07-21 15:36:35 +08:00
let mut inferencer = env.get_inferencer();
let statements = parse_program(source).unwrap();
2021-07-22 17:07:49 +08:00
let statements = statements
2021-07-21 15:36:35 +08:00
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
2021-07-22 17:07:49 +08:00
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
2021-07-21 15:59:01 +08:00
for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.stringify(
*v,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
println!("{}: {}", k, name);
}
2021-07-21 15:36:35 +08:00
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(*k).unwrap();
let name = inferencer.unifier.stringify(
*ty,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
2021-07-27 11:58:35 +08:00
assert_eq!(inferencer.virtual_checks.len(), virtuals.len());
for ((a, b), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
let a = inferencer.unifier.stringify(
*a,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
let b = inferencer.unifier.stringify(
*b,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(&a, x);
assert_eq!(&b, y);
}
2021-07-21 15:36:35 +08:00
}