1
0
forked from M-Labs/nac3
nac3/nac3core/src/codegen/test.rs

450 lines
16 KiB
Rust
Raw Normal View History

2021-09-30 17:07:48 +08:00
use crate::{
codegen::{
classes::{ListType, NDArrayType, ProxyType, RangeType},
2024-06-12 14:45:03 +08:00
concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask,
CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
2021-11-20 19:50:25 +08:00
symbol_resolver::{SymbolResolver, ValueEnum},
2021-09-30 17:07:48 +08:00
toplevel::{
2024-06-12 14:45:03 +08:00
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
2021-09-30 17:07:48 +08:00
},
typecheck::{
2021-08-25 15:30:36 +08:00
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
2021-09-30 17:07:48 +08:00
},
};
2021-08-12 13:55:15 +08:00
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
2024-06-12 14:45:03 +08:00
OptimizationLevel,
};
2021-11-03 17:11:00 +08:00
use nac3parser::{
2021-09-30 17:07:48 +08:00
ast::{fold::Fold, StrRef},
parser::parse_program,
};
2021-11-20 19:50:25 +08:00
use parking_lot::RwLock;
2021-08-27 13:04:51 +08:00
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
2021-08-12 13:55:15 +08:00
struct Resolver {
2021-09-22 17:19:27 +08:00
id_to_type: HashMap<StrRef, Type>,
id_to_def: RwLock<HashMap<StrRef, DefinitionId>>,
class_names: HashMap<StrRef, Type>,
2021-08-12 13:55:15 +08:00
}
impl Resolver {
2021-09-22 17:19:27 +08:00
pub fn add_id_def(&self, id: StrRef, def: DefinitionId) {
self.id_to_def.write().insert(id, def);
}
}
2021-08-12 13:55:15 +08:00
impl SymbolResolver for Resolver {
2022-02-21 18:27:46 +08:00
fn get_default_param_value(
&self,
_: &nac3parser::ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!()
}
fn get_symbol_type(
&self,
_: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>],
_: &PrimitiveStore,
str: StrRef,
2022-01-13 03:21:26 +08:00
) -> Result<Type, String> {
self.id_to_type.get(&str).cloned().ok_or_else(|| format!("cannot find symbol `{}`", str))
2021-08-12 13:55:15 +08:00
}
fn get_symbol_value<'ctx, 'a>(
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, 'a>,
2021-11-20 19:50:25 +08:00
) -> Option<ValueEnum<'ctx>> {
2021-08-12 13:55:15 +08:00
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
2022-02-21 18:27:46 +08:00
self.id_to_def
.read()
.get(&id)
.cloned()
2024-06-12 14:45:03 +08:00
.ok_or_else(|| HashSet::from([format!("cannot find symbol `{}`", id)]))
2021-08-12 13:55:15 +08:00
}
2022-02-12 21:21:56 +08:00
fn get_string_id(&self, _: &str) -> i32 {
unimplemented!()
}
2022-03-05 00:27:51 +08:00
2022-03-26 18:52:08 +08:00
fn get_exception_id(&self, _tyid: usize) -> usize {
2022-03-05 00:27:51 +08:00
unimplemented!()
}
2021-08-12 13:55:15 +08:00
}
2021-08-25 15:30:36 +08:00
/// Type-checks and compiles a function `testing(a: int32, b: int32) -> int32` whose
/// body uses integer addition and a conditional expression, then compares the
/// optimized LLVM IR (including debug metadata) against a fixed expectation.
#[test]
fn test_primitives() {
    let source = indoc! { "
        c = a + b
        d = a if c == 1 else 0
        return d
        "};
    let statements = parse_program(source, Default::default()).unwrap();

    let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
    let mut unifier = composer.unifier.clone();
    let primitives = composer.primitives_ty;
    let top_level = Arc::new(composer.make_top_level_context());
    unifier.top_level = Some(top_level.clone());

    let resolver = Arc::new(Resolver {
        id_to_type: HashMap::new(),
        id_to_def: RwLock::new(HashMap::new()),
        class_names: Default::default(),
    }) as Arc<dyn SymbolResolver + Send + Sync>;

    let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];

    let signature = FunSignature {
        args: vec![
            FuncArg { name: "a".into(), ty: primitives.int32, default_value: None },
            FuncArg { name: "b".into(), ty: primitives.int32, default_value: None },
        ],
        ret: primitives.int32,
        vars: VarMap::new(),
    };
    let mut store = ConcreteTypeStore::new();
    let mut cache = HashMap::new();
    let signature = store.from_signature(&mut unifier, &primitives, &signature, &mut cache);
    let signature = store.add_cty(signature);

    let mut function_data = FunctionData {
        resolver: resolver.clone(),
        bound_variables: Vec::new(),
        return_type: Some(primitives.int32),
    };
    let mut virtual_checks = Vec::new();
    let mut calls = HashMap::new();
    let mut identifiers: HashSet<_> = ["a".into(), "b".into()].iter().cloned().collect();
    let mut inferencer = Inferencer {
        top_level: &top_level,
        function_data: &mut function_data,
        unifier: &mut unifier,
        variable_mapping: Default::default(),
        primitives: &primitives,
        virtual_checks: &mut virtual_checks,
        calls: &mut calls,
        defined_identifiers: identifiers.clone(),
        in_handler: false,
    };
    inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
    inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);

    let statements = statements
        .into_iter()
        .map(|v| inferencer.fold_stmt(v))
        .collect::<Result<Vec<_>, _>>()
        .unwrap();

    inferencer.check_block(&statements, &mut identifiers).unwrap();

    // Move the definitions out of the composer's context into a fresh context that
    // also carries the shared unifier used for codegen.
    let top_level = Arc::new(TopLevelContext {
        definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
        unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])),
        personality_symbol: None,
    });

    let task = CodeGenTask {
        subst: Default::default(),
        symbol_name: "testing".into(),
        body: Arc::new(statements),
        unifier_index: 0,
        calls: Arc::new(calls),
        resolver,
        store,
        signature,
        id: 0,
    };
    let f = Arc::new(WithCall::new(Box::new(|module| {
        // the following IR is equivalent to
        // ```
        // ; ModuleID = 'test.ll'
        // source_filename = "test"
        //
        // ; Function Attrs: norecurse nounwind readnone
        // define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
        // init:
        //   %add = add i32 %1, %0
        //   %cmp = icmp eq i32 %add, 1
        //   %ifexpr = select i1 %cmp, i32 %0, i32 0
        //   ret i32 %ifexpr
        // }
        //
        // attributes #0 = { norecurse nounwind readnone }
        // ```
        // after O2 optimization
        //
        // LLVM's printer separates the module header, each function, the attribute
        // groups and the metadata sections with blank lines, and indents instructions
        // by two spaces. Only the outer whitespace is trimmed before comparing, so the
        // expected text must reproduce that layout exactly.
        let expected = indoc! {"
            ; ModuleID = 'test'
            source_filename = \"test\"

            ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
            define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
            init:
              %add = add i32 %1, %0, !dbg !9
              %cmp = icmp eq i32 %add, 1, !dbg !10
              %. = select i1 %cmp, i32 %0, i32 0, !dbg !11
              ret i32 %., !dbg !12
            }

            attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }

            !llvm.module.flags = !{!0, !1}
            !llvm.dbg.cu = !{!2}

            !0 = !{i32 2, !\"Debug Info Version\", i32 3}
            !1 = !{i32 2, !\"Dwarf Version\", i32 4}
            !2 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
            !3 = !DIFile(filename: \"unknown\", directory: \"\")
            !4 = distinct !DISubprogram(name: \"testing\", linkageName: \"testing\", scope: null, file: !3, line: 1, type: !5, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !8)
            !5 = !DISubroutineType(flags: DIFlagPublic, types: !6)
            !6 = !{!7}
            !7 = !DIBasicType(name: \"_\", flags: DIFlagPublic)
            !8 = !{}
            !9 = !DILocation(line: 1, column: 9, scope: !4)
            !10 = !DILocation(line: 2, column: 15, scope: !4)
            !11 = !DILocation(line: 0, scope: !4)
            !12 = !DILocation(line: 3, column: 8, scope: !4)
        "}
        .trim();
        assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
    })));
    Target::initialize_all(&InitializationConfig::default());

    let llvm_options = CodeGenLLVMOptions {
        opt_level: OptimizationLevel::Default,
        target: CodeGenTargetMachineOptions::from_host_triple(),
    };
    let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
    registry.add_task(task);
    registry.wait_tasks_complete(handles);
}
/// Compiles `testing(a)`, which calls a separately type-checked top-level function
/// `foo`. Checks the optimized module (with `foo` inlined into `testing` and `foo.0`
/// still emitted) against a fixed IR expectation.
#[test]
fn test_simple_call() {
    let source_1 = indoc! { "
        a = foo(a)
        return a * 2
        "};
    let statements_1 = parse_program(source_1, Default::default()).unwrap();
    let source_2 = indoc! { "
        return a + 1
        "};
    let statements_2 = parse_program(source_2, Default::default()).unwrap();

    let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
    let mut unifier = composer.unifier.clone();
    let primitives = composer.primitives_ty;
    let top_level = Arc::new(composer.make_top_level_context());
    unifier.top_level = Some(top_level.clone());

    // `testing` and `foo` share the signature `(a: int32) -> int32`.
    let signature = FunSignature {
        args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }],
        ret: primitives.int32,
        vars: VarMap::new(),
    };
    let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone()));
    let mut store = ConcreteTypeStore::new();
    let mut cache = HashMap::new();
    let signature = store.from_signature(&mut unifier, &primitives, &signature, &mut cache);
    let signature = store.add_cty(signature);

    // Register `foo` as a top-level function; its resolver and body are attached below.
    let foo_id = top_level.definitions.read().len();
    top_level.definitions.write().push(Arc::new(RwLock::new(TopLevelDef::Function {
        name: "foo".to_string(),
        simple_name: "foo".into(),
        signature: fun_ty,
        var_id: vec![],
        instance_to_stmt: HashMap::new(),
        instance_to_symbol: HashMap::new(),
        resolver: None,
        codegen_callback: None,
        loc: None,
    })));

    let resolver = Resolver {
        id_to_type: HashMap::new(),
        id_to_def: RwLock::new(HashMap::new()),
        class_names: Default::default(),
    };
    resolver.add_id_def("foo".into(), DefinitionId(foo_id));
    let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>;

    if let TopLevelDef::Function { resolver: r, .. } =
        &mut *top_level.definitions.read()[foo_id].write()
    {
        *r = Some(resolver.clone());
    } else {
        unreachable!()
    }

    let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
    let mut function_data = FunctionData {
        resolver: resolver.clone(),
        bound_variables: Vec::new(),
        return_type: Some(primitives.int32),
    };
    let mut virtual_checks = Vec::new();
    let mut calls = HashMap::new();
    let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].iter().cloned().collect();
    let mut inferencer = Inferencer {
        top_level: &top_level,
        function_data: &mut function_data,
        unifier: &mut unifier,
        variable_mapping: Default::default(),
        primitives: &primitives,
        virtual_checks: &mut virtual_checks,
        calls: &mut calls,
        defined_identifiers: identifiers.clone(),
        in_handler: false,
    };
    inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
    inferencer.variable_mapping.insert("foo".into(), fun_ty);

    let statements_1 = statements_1
        .into_iter()
        .map(|v| inferencer.fold_stmt(v))
        .collect::<Result<Vec<_>, _>>()
        .unwrap();
    // Keep the call map for `testing`, then start a fresh one for `foo`'s body.
    let calls1 = inferencer.calls.clone();
    inferencer.calls.clear();
    let statements_2 = statements_2
        .into_iter()
        .map(|v| inferencer.fold_stmt(v))
        .collect::<Result<Vec<_>, _>>()
        .unwrap();

    if let TopLevelDef::Function { instance_to_stmt, .. } =
        &mut *top_level.definitions.read()[foo_id].write()
    {
        instance_to_stmt.insert(
            "".to_string(),
            FunInstance {
                body: Arc::new(statements_2),
                calls: Arc::new(inferencer.calls.clone()),
                subst: Default::default(),
                unifier_id: 0,
            },
        );
    } else {
        unreachable!()
    }

    inferencer.check_block(&statements_1, &mut identifiers).unwrap();

    // Move the definitions out of the composer's context into a fresh context that
    // also carries the shared unifier used for codegen.
    let top_level = Arc::new(TopLevelContext {
        definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
        unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])),
        personality_symbol: None,
    });

    let task = CodeGenTask {
        subst: Default::default(),
        symbol_name: "testing".to_string(),
        body: Arc::new(statements_1),
        calls: Arc::new(calls1),
        unifier_index: 0,
        resolver,
        signature,
        store,
        id: 0,
    };
    let f = Arc::new(WithCall::new(Box::new(|module| {
        // LLVM's printer separates the module header, each function, the attribute
        // groups and the metadata sections with blank lines, and indents instructions
        // by two spaces. Only the outer whitespace is trimmed before comparing, so the
        // expected text must reproduce that layout exactly.
        let expected = indoc! {"
            ; ModuleID = 'test'
            source_filename = \"test\"

            ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
            define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {
            init:
              %add.i = shl i32 %0, 1, !dbg !10
              %mul = add i32 %add.i, 2, !dbg !10
              ret i32 %mul, !dbg !10
            }

            ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
            define i32 @foo.0(i32 %0) local_unnamed_addr #0 !dbg !11 {
            init:
              %add = add i32 %0, 1, !dbg !12
              ret i32 %add, !dbg !12
            }

            attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }

            !llvm.module.flags = !{!0, !1}
            !llvm.dbg.cu = !{!2, !4}

            !0 = !{i32 2, !\"Debug Info Version\", i32 3}
            !1 = !{i32 2, !\"Dwarf Version\", i32 4}
            !2 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
            !3 = !DIFile(filename: \"unknown\", directory: \"\")
            !4 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
            !5 = distinct !DISubprogram(name: \"testing\", linkageName: \"testing\", scope: null, file: !3, line: 1, type: !6, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !9)
            !6 = !DISubroutineType(flags: DIFlagPublic, types: !7)
            !7 = !{!8}
            !8 = !DIBasicType(name: \"_\", flags: DIFlagPublic)
            !9 = !{}
            !10 = !DILocation(line: 2, column: 12, scope: !5)
            !11 = distinct !DISubprogram(name: \"foo.0\", linkageName: \"foo.0\", scope: null, file: !3, line: 1, type: !6, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !4, retainedNodes: !9)
            !12 = !DILocation(line: 1, column: 12, scope: !11)
        "}
        .trim();
        assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
    })));
    Target::initialize_all(&InitializationConfig::default());

    let llvm_options = CodeGenLLVMOptions {
        opt_level: OptimizationLevel::Default,
        target: CodeGenTargetMachineOptions::from_host_triple(),
    };
    let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
    registry.add_task(task);
    registry.wait_tasks_complete(handles);
}
/// A `ListType` built over an `i32` element type must pass `ListType::is_type`
/// when checked against the generator's size type.
#[test]
fn test_classes_list_type_new() {
    let context = inkwell::context::Context::create();
    // Second argument presumably selects the size_t width (64 here) — TODO confirm.
    let codegen = DefaultCodeGenerator::new(String::new(), 64);

    let elem_ty = context.i32_type();
    let size_ty = codegen.get_size_type(&context);
    let list_ty = ListType::new(&codegen, &context, elem_ty.into());

    let verdict = ListType::is_type(list_ty.as_base_type(), size_ty);
    assert!(verdict.is_ok());
}
/// A freshly constructed `RangeType` must be recognized by `RangeType::is_type`.
#[test]
fn test_classes_range_type_new() {
    let context = inkwell::context::Context::create();

    let range_ty = RangeType::new(&context);
    let base = range_ty.as_base_type();

    assert!(RangeType::is_type(base).is_ok());
}
/// An `NDArrayType` built over an `i32` element type must pass
/// `NDArrayType::is_type` when checked against the generator's size type.
#[test]
fn test_classes_ndarray_type_new() {
    let context = inkwell::context::Context::create();
    // Second argument presumably selects the size_t width (64 here) — TODO confirm.
    let codegen = DefaultCodeGenerator::new(String::new(), 64);

    let elem_ty = context.i32_type();
    let size_ty = codegen.get_size_type(&context);
    let ndarray_ty = NDArrayType::new(&codegen, &context, elem_ty.into());

    let verdict = NDArrayType::is_type(ndarray_ty.as_base_type(), size_ty);
    assert!(verdict.is_ok());
}