// nac3core/src/toplevel/test.rs — tests for the top-level composer
// (from a fork of M-Labs/nac3).

use std::{collections::HashMap, sync::Arc};

use indoc::indoc;
use parking_lot::Mutex;
use test_case::test_case;

use nac3parser::{
    ast::{fold::Fold, FileName},
    parser::parse_program,
};

use super::*;
use crate::{
    codegen::CodeGenContext,
    symbol_resolver::{SymbolResolver, ValueEnum},
    toplevel::{helper::PrimDef, DefinitionId},
    typecheck::{
        type_inferencer::PrimitiveStore,
        typedef::{into_var_map, Type, Unifier},
    },
};
struct ResolverInternal {
2021-09-22 17:19:27 +08:00
id_to_type: Mutex<HashMap<StrRef, Type>>,
id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
class_names: Mutex<HashMap<StrRef, Type>>,
}
impl ResolverInternal {
2021-09-22 17:19:27 +08:00
fn add_id_def(&self, id: StrRef, def: DefinitionId) {
self.id_to_def.lock().insert(id, def);
}
2021-09-22 17:19:27 +08:00
fn add_id_type(&self, id: StrRef, ty: Type) {
self.id_to_type.lock().insert(id, ty);
}
}
/// Minimal `SymbolResolver` for the composer tests. It wraps a shared `ResolverInternal`,
/// so the test body can keep adding identifiers after the resolver has been handed out.
struct Resolver(Arc<ResolverInternal>);
impl SymbolResolver for Resolver {
2022-02-21 18:27:46 +08:00
fn get_default_param_value(
&self,
2023-10-26 13:52:40 +08:00
_: &ast::Expr,
2022-02-21 18:27:46 +08:00
) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!()
}
2022-02-12 21:21:56 +08:00
2021-10-16 18:08:13 +08:00
fn get_symbol_type(
&self,
_: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>],
_: &PrimitiveStore,
str: StrRef,
2022-01-13 03:21:26 +08:00
) -> Result<Type, String> {
self.0
.id_to_type
.lock()
.get(&str)
2024-06-17 14:07:38 +08:00
.copied()
.ok_or_else(|| format!("cannot find symbol `{str}`"))
}
2024-06-17 14:07:38 +08:00
fn get_symbol_value<'ctx>(
&self,
_: StrRef,
2024-06-17 14:07:38 +08:00
_: &mut CodeGenContext<'ctx, '_>,
2021-11-20 19:50:25 +08:00
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
2024-06-12 14:45:03 +08:00
self.0
.id_to_def
.lock()
.get(&id)
2024-06-17 14:07:38 +08:00
.copied()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
}
2022-02-12 21:21:56 +08:00
fn get_string_id(&self, _: &str) -> i32 {
unimplemented!()
}
2022-03-05 00:27:51 +08:00
2022-03-26 18:52:08 +08:00
fn get_exception_id(&self, _tyid: usize) -> usize {
2022-03-05 00:27:51 +08:00
unimplemented!()
}
}
#[test_case(
vec![
indoc! {"
def fun(a: int32) -> int32:
return a
"},
indoc! {"
class A:
def __init__(self):
self.a: int32 = 3
"},
indoc! {"
class B:
def __init__(self):
self.b: float = 4.3
2021-09-22 17:19:27 +08:00
def fun(self):
self.b = self.b + 3.0
"},
indoc! {"
def foo(a: float):
a + 1.0
2021-08-31 17:40:38 +08:00
"},
indoc! {"
class C(B):
def __init__(self):
self.c: int32 = 4
2021-08-31 17:40:38 +08:00
self.a: bool = True
"},
];
"register"
)]
fn test_simple_register(source: Vec<&str>) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
for s in source {
2024-06-17 14:07:38 +08:00
let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone();
2024-06-17 14:07:38 +08:00
composer.register_top_level(ast, None, "", false).unwrap();
}
}
#[test_case(
indoc! {"
class A:
def foo(self):
pass
a = A()
"};
"register"
)]
fn test_simple_register_without_constructor(source: &str) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
2024-06-17 14:07:38 +08:00
let ast = parse_program(source, FileName::default()).unwrap();
let ast = ast[0].clone();
2024-06-17 14:07:38 +08:00
composer.register_top_level(ast, None, "", true).unwrap();
}
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
def fun(a: int32) -> int32:
return a
"},
indoc! {"
def foo(a: float):
a + 1.0
"},
indoc! {"
def f(b: int64) -> int32:
return 3
"},
],
2024-06-17 14:07:38 +08:00
&[
"fn[[a:0], 0]",
"fn[[a:2], 4]",
"fn[[b:1], 0]",
],
2024-06-17 14:07:38 +08:00
&[
"fun",
"foo",
"f"
];
"function compose"
)]
2024-06-17 14:07:38 +08:00
fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = Arc::new(ResolverInternal {
2024-06-17 14:07:38 +08:00
id_to_def: Mutex::default(),
id_to_type: Mutex::default(),
class_names: Mutex::default(),
});
2021-10-16 18:08:13 +08:00
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
2024-06-17 14:07:38 +08:00
let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) =
2024-06-17 14:07:38 +08:00
composer.register_top_level(ast, Some(resolver.clone()), "", false).unwrap();
2021-10-16 18:08:13 +08:00
internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
}
composer.start_analysis(true).unwrap();
2021-12-02 10:45:46 +08:00
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.builtin_num).enumerate()
{
let def = &*def.read();
if let TopLevelDef::Function { signature, name, .. } = def {
2022-02-21 18:27:46 +08:00
let ty_str = composer.unifier.internal_stringify(
*signature,
&mut |id| id.to_string(),
&mut |id| id.to_string(),
&mut None,
);
assert_eq!(ty_str, tys[i]);
assert_eq!(name, names[i]);
}
}
}
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class A():
a: int32
def __init__(self):
self.a = 3
def fun(self, b: B):
pass
def foo(self, a: T, b: V):
pass
"},
indoc! {"
class B(C):
def __init__(self):
pass
"},
indoc! {"
class C(A):
def __init__(self):
pass
def fun(self, b: B):
a = 1
pass
"},
indoc! {"
def foo(a: A):
pass
"},
indoc! {"
def ff(a: T) -> V:
pass
"}
],
2024-06-17 14:07:38 +08:00
&[];
"simple class compose"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class Generic_A(Generic[V], B):
a: int64
def __init__(self):
self.a = 123123123123
def fun(self, a: int32) -> V:
pass
"},
indoc! {"
class B:
aa: bool
def __init__(self):
self.aa = False
def foo(self, b: T):
pass
"}
],
2024-06-17 14:07:38 +08:00
&[];
"generic class"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
def foo(a: list[int32], b: tuple[T, float]) -> A[B, bool]:
pass
"},
indoc! {"
class A(Generic[T, V]):
a: T
b: V
def __init__(self, v: V):
self.a = 1
self.b = v
def fun(self, a: T) -> V:
pass
"},
indoc! {"
def gfun(a: A[list[float], int32]):
pass
"},
indoc! {"
class B:
def __init__(self):
pass
"}
],
2024-06-17 14:07:38 +08:00
&[];
"list tuple generic"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class A(Generic[T, V]):
a: A[float, bool]
b: B
def __init__(self, a: A[float, bool], b: B):
self.a = a
self.b = b
def fun(self, a: A[float, bool]) -> A[bool, int32]:
pass
"},
indoc! {"
class B(A[int64, bool]):
def __init__(self):
pass
def foo(self, b: B) -> B:
pass
def bar(self, a: A[list[B], int32]) -> tuple[A[virtual[A[B, int32]], bool], B]:
pass
"}
],
2024-06-17 14:07:38 +08:00
&[];
"self1"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class A(Generic[T]):
a: int32
b: T
c: A[int64]
def __init__(self, t: T):
self.a = 3
self.b = T
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
pass
def foo(self, c: C):
pass
"},
indoc! {"
class B(Generic[V], A[float]):
d: C
def __init__(self):
pass
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
# override
pass
"},
indoc! {"
class C(B[bool]):
e: int64
def __init__(self):
pass
"}
],
2024-06-17 14:07:38 +08:00
&[];
"inheritance_override"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class A(Generic[T]):
def __init__(self):
pass
def fun(self, a: A[T]) -> A[T]:
pass
"}
],
2024-06-17 14:07:38 +08:00
&["application of type vars to generic class is not currently supported (at unknown:4:24)"];
"err no type var in generic app"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class A(B):
def __init__(self):
pass
"},
indoc! {"
class B(A):
def __init__(self):
pass
"}
],
2024-06-17 14:07:38 +08:00
&["cyclic inheritance detected"];
"cyclic1"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class A(B[bool, int64]):
def __init__(self):
pass
"},
indoc! {"
class B(Generic[V, T], C[int32]):
def __init__(self):
pass
"},
indoc! {"
class C(Generic[T], A):
def __init__(self):
pass
"},
],
2024-06-17 14:07:38 +08:00
&["cyclic inheritance detected"];
"cyclic2"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class A:
pass
"}
],
2024-06-17 14:07:38 +08:00
&["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"];
"simple pass in class"
)]
2021-09-10 21:26:39 +08:00
#[test_case(
2024-06-17 14:07:38 +08:00
&[indoc! {"
2021-09-10 21:26:39 +08:00
class A:
def __init__():
pass
"}],
2024-06-17 14:07:38 +08:00
&["__init__ method must have a `self` parameter (at unknown:2:5)"];
2021-09-10 21:26:39 +08:00
"err no self_1"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
2021-09-10 21:26:39 +08:00
indoc! {"
class A(B, Generic[T], C):
def __init__(self):
pass
"},
indoc! {"
class B:
def __init__(self):
pass
"},
indoc! {"
class C:
def __init__(self):
pass
"}
],
2024-06-17 14:07:38 +08:00
&["a class definition can only have at most one base class declaration and one generic declaration (at unknown:1:24)"];
2021-09-10 21:26:39 +08:00
"err multiple inheritance"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class A(Generic[T]):
a: int32
b: T
c: A[int64]
def __init__(self, t: T):
self.a = 3
self.b = T
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
pass
"},
indoc! {"
class B(Generic[V], A[float]):
def __init__(self):
pass
def fun(self, a: int32, b: T) -> list[virtual[B[int32]]]:
# override
pass
"}
],
2024-06-17 14:07:38 +08:00
&["method fun has same name as ancestors' method, but incompatible type"];
"err_incompatible_inheritance_method"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class A(Generic[T]):
a: int32
b: T
c: A[int64]
def __init__(self, t: T):
self.a = 3
self.b = T
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
pass
"},
indoc! {"
class B(Generic[V], A[float]):
a: int32
def __init__(self):
pass
def fun(self, a: int32, b: T) -> list[virtual[B[bool]]]:
# override
pass
"}
],
2024-06-17 14:07:38 +08:00
&["field `a` has already declared in the ancestor classes"];
"err_incompatible_inheritance_field"
)]
#[test_case(
2024-06-17 14:07:38 +08:00
&[
indoc! {"
class A:
def __init__(self):
pass
"},
indoc! {"
class A:
a: int32
def __init__(self):
pass
"}
],
2024-06-17 14:07:38 +08:00
&["duplicate definition of class `A` (at unknown:1:1)"];
"class same name"
)]
2024-06-17 14:07:38 +08:00
fn test_analyze(source: &[&str], res: &[&str]) {
2021-09-10 21:26:39 +08:00
let print = false;
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar(
vec![
("T".into(), vec![]),
("V".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int32]),
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
],
&mut composer.unifier,
print,
);
2021-10-16 18:08:13 +08:00
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
2024-06-17 14:07:38 +08:00
let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) = {
2024-06-17 14:07:38 +08:00
match composer.register_top_level(ast, Some(resolver.clone()), "", false) {
2021-09-10 21:26:39 +08:00
Ok(x) => x,
Err(msg) => {
if print {
2024-06-17 14:07:38 +08:00
println!("{msg}");
2021-09-10 21:26:39 +08:00
} else {
assert_eq!(res[0], msg);
}
2021-09-12 04:40:40 +08:00
return;
2021-09-10 21:26:39 +08:00
}
}
};
2021-09-22 17:19:27 +08:00
internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
}
if let Err(msg) = composer.start_analysis(false) {
2021-09-10 21:26:39 +08:00
if print {
println!("{}", msg.iter().sorted().join("\n----------\n"));
2021-09-10 21:26:39 +08:00
} else {
assert_eq!(res[0], msg.iter().next().unwrap());
2021-09-10 21:26:39 +08:00
}
} else {
// skip 5 to skip primitives
let mut res_vec: Vec<String> = Vec::new();
2021-12-02 10:45:46 +08:00
for (def, _) in composer.definition_ast_list.iter().skip(composer.builtin_num) {
2021-09-10 21:26:39 +08:00
let def = &*def.read();
res_vec.push(format!("{}\n", def.to_string(composer.unifier.borrow_mut())));
2021-09-10 21:26:39 +08:00
}
insta::assert_debug_snapshot!(res_vec);
}
}
#[test_case(
vec![
indoc! {"
def fun(a: int32, b: int32) -> int32:
return a + b
"},
indoc! {"
def fib(n: int32) -> int32:
if n <= 2:
return 1
a = fib(n - 1)
b = fib(n - 2)
return fib(n - 1)
"}
],
2024-06-17 14:07:38 +08:00
&[];
"simple function"
)]
#[test_case(
vec![
indoc! {"
class A:
a: int32
def __init__(self):
self.a = 3
def fun(self) -> int32:
b = self.a + 3
return b * self.a
def clone(self) -> A:
SELF = self
return SELF
def sum(self) -> int32:
if self.a == 0:
return self.a
else:
a = self.a
self.a = self.a - 1
return a + self.sum()
def fib(self, a: int32) -> int32:
if a <= 2:
return 1
return self.fib(a - 1) + self.fib(a - 2)
"},
indoc! {"
def fun(a: A) -> int32:
return a.fun() + 2
"}
],
2024-06-17 14:07:38 +08:00
&[];
"simple class body"
)]
2021-09-17 00:35:58 +08:00
#[test_case(
vec![
indoc! {"
def fun(a: V, c: G, t: T) -> V:
b = a
cc = c
ret = fun(b, cc, t)
return ret * ret
"},
indoc! {"
def sum_three(l: list[V]) -> V:
return l[0] + l[1] + l[2]
"},
indoc! {"
def sum_sq_pair(p: tuple[V, V]) -> list[V]:
a = p[0]
b = p[1]
a = a**a
b = b**b
return [a, b]
2021-09-17 00:35:58 +08:00
"}
],
2024-06-17 14:07:38 +08:00
&[];
2021-09-17 00:35:58 +08:00
"type var fun"
)]
#[test_case(
vec![
indoc! {"
class A(Generic[G]):
a: G
b: bool
def __init__(self, aa: G):
self.a = aa
if 2 > 1:
self.b = True
else:
# self.b = False
pass
def fun(self, a: G) -> list[G]:
ret = [a, self.a]
return ret if self.b else self.fun(self.a)
"}
],
2024-06-17 14:07:38 +08:00
&[];
"type var class"
)]
#[test_case(
vec![
indoc! {"
class A:
def fun(self):
pass
"},
indoc!{"
class B:
a: int32
b: bool
def __init__(self):
# self.b = False
if 3 > 2:
self.a = 3
self.b = False
else:
self.a = 4
self.b = True
"}
],
2024-06-17 14:07:38 +08:00
&[];
"no_init_inst_check"
)]
2024-06-17 14:07:38 +08:00
fn test_inference(source: Vec<&str>, res: &[&str]) {
let print = true;
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar(
vec![
("T".into(), vec![]),
(
"V".into(),
vec![
composer.primitives_ty.float,
composer.primitives_ty.int32,
composer.primitives_ty.int64,
],
),
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
],
&mut composer.unifier,
print,
);
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
2024-06-17 14:07:38 +08:00
let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) = {
2024-06-17 14:07:38 +08:00
match composer.register_top_level(ast, Some(resolver.clone()), "", false) {
Ok(x) => x,
Err(msg) => {
if print {
2024-06-17 14:07:38 +08:00
println!("{msg}");
} else {
assert_eq!(res[0], msg);
}
return;
}
}
};
2021-10-16 18:08:13 +08:00
internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
}
if let Err(msg) = composer.start_analysis(true) {
if print {
println!("{}", msg.iter().sorted().join("\n----------\n"));
} else {
assert_eq!(res[0], msg.iter().next().unwrap());
}
} else {
// skip 5 to skip primitives
let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier };
2024-06-17 14:07:38 +08:00
for (def, _) in composer.definition_ast_list.iter().skip(composer.builtin_num) {
let def = &*def.read();
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
println!(
"=========`{}`: number of instances: {}===========",
name,
instance_to_stmt.len()
);
2024-06-17 14:07:38 +08:00
for inst in instance_to_stmt {
let ast = &inst.1.body;
2021-09-22 17:19:27 +08:00
for b in ast.iter() {
println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap());
println!("--------------------");
}
println!("\n");
}
}
}
}
}
fn make_internal_resolver_with_tvar(
2021-09-22 17:19:27 +08:00
tvars: Vec<(StrRef, Vec<Type>)>,
unifier: &mut Unifier,
print: bool,
) -> Arc<ResolverInternal> {
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let res: Arc<ResolverInternal> = ResolverInternal {
id_to_def: Mutex::new(HashMap::from([("list".into(), PrimDef::List.id())])),
id_to_type: tvars
.into_iter()
.map(|(name, range)| {
2021-10-16 18:08:13 +08:00
(name, {
let tvar = unifier.get_fresh_var_with_range(range.as_slice(), None, None);
if print {
println!("{}: {:?}, typevar{}", name, tvar.ty, tvar.id);
}
tvar.ty
})
})
.collect::<HashMap<_, _>>()
.into(),
class_names: Mutex::new(HashMap::from([("list".into(), list)])),
}
.into();
if print {
println!();
}
res
}
/// AST folder that turns each node's optional type annotation into a printable string.
struct TypeToStringFolder<'u> {
    /// Unifier used to stringify the annotated types.
    unifier: &'u mut Unifier,
}
/// Replaces every `Option<Type>` annotation with its string form via
/// `Unifier::internal_stringify`.
///
/// The two id formatters produce `class{id}` and `typevar{id}`; presumably the first is
/// applied to object/class ids and the second to type-variable ids. A missing annotation
/// becomes `"None"`.
impl Fold<Option<Type>> for TypeToStringFolder<'_> {
    type TargetU = String;
    type Error = String;

    fn map_user(&mut self, user: Option<Type>) -> Result<Self::TargetU, Self::Error> {
        Ok(if let Some(ty) = user {
            self.unifier.internal_stringify(
                ty,
                &mut |id| format!("class{id}"),
                &mut |id| format!("typevar{id}"),
                &mut None,
            )
        } else {
            "None".into()
        })
    }
}