nac3/nac3core/src/codegen/test.rs

350 lines
11 KiB
Rust

use crate::{
codegen::{CodeGenTask, WithCall, WorkerRegistry},
location::Location,
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{DefinitionId, FunInstance, TopLevelComposer, TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
},
};
use indoc::indoc;
use parking_lot::RwLock;
use rustpython_parser::{ast::fold::Fold, parser::parse_program};
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
struct Resolver {
id_to_type: HashMap<String, Type>,
id_to_def: RwLock<HashMap<String, DefinitionId>>,
class_names: HashMap<String, Type>,
}
impl Resolver {
pub fn add_id_def(&self, id: String, def: DefinitionId) {
self.id_to_def.write().insert(id, def);
}
}
impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
self.id_to_type.get(str).cloned()
}
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
unimplemented!()
}
fn get_symbol_location(&self, _: &str) -> Option<Location> {
unimplemented!()
}
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
self.id_to_def.read().get(id).cloned()
}
}
#[test]
fn test_primitives() {
let source = indoc! { "
c = a + b
d = a if c == 1 else 0
return d
"};
let statements = parse_program(source).unwrap();
let composer = TopLevelComposer::new();
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone());
let resolver = Arc::new(Box::new(Resolver {
id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(),
}) as Box<dyn SymbolResolver + Send + Sync>);
let threads = ["test"];
let signature = FunSignature {
args: vec![
FuncArg { name: "a".to_string(), ty: primitives.int32, default_value: None },
FuncArg { name: "b".to_string(), ty: primitives.int32, default_value: None },
],
ret: primitives.int32,
vars: HashMap::new(),
};
let mut function_data = FunctionData {
resolver: resolver.clone(),
bound_variables: Vec::new(),
return_type: Some(primitives.int32),
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".to_string(), "b".to_string()].iter().cloned().collect();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
unifier: &mut unifier,
variable_mapping: Default::default(),
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: identifiers.clone(),
};
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])),
});
let unifier = (unifier.get_shared_unifier(), primitives);
let task = CodeGenTask {
subst: Default::default(),
symbol_name: "testing".to_string(),
body: statements,
resolver,
unifier,
calls,
signature,
};
let f = Arc::new(WithCall::new(Box::new(|module| {
// the following IR is equivalent to
// ```
// ; ModuleID = 'test.ll'
// source_filename = "test"
//
// ; Function Attrs: norecurse nounwind readnone
// define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
// init:
// %add = add i32 %1, %0
// %cmp = icmp eq i32 %add, 1
// %ifexpr = select i1 %cmp, i32 %0, i32 0
// ret i32 %ifexpr
// }
//
// attributes #0 = { norecurse nounwind readnone }
// ```
// after O2 optimization
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
define i32 @testing(i32 %0, i32 %1) {
init:
%a = alloca i32, align 4
store i32 %0, i32* %a, align 4
%b = alloca i32, align 4
store i32 %1, i32* %b, align 4
%tmp = alloca i32, align 4
%tmp4 = alloca i32, align 4
br label %body
body: ; preds = %init
%load = load i32, i32* %a, align 4
%load1 = load i32, i32* %b, align 4
%add = add i32 %load, %load1
store i32 %add, i32* %tmp, align 4
%load2 = load i32, i32* %tmp, align 4
%cmp = icmp eq i32 %load2, 1
br i1 %cmp, label %then, label %else
then: ; preds = %body
%load3 = load i32, i32* %a, align 4
br label %cont
else: ; preds = %body
br label %cont
cont: ; preds = %else, %then
%ifexpr = phi i32 [ %load3, %then ], [ 0, %else ]
store i32 %ifexpr, i32* %tmp4, align 4
%load5 = load i32, i32* %tmp4, align 4
ret i32 %load5
}
"}
.trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
})));
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}
#[test]
fn test_simple_call() {
let source_1 = indoc! { "
a = foo(a)
return a * 2
"};
let statements_1 = parse_program(source_1).unwrap();
let source_2 = indoc! { "
return a + 1
"};
let statements_2 = parse_program(source_2).unwrap();
let composer = TopLevelComposer::new();
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone());
let signature = FunSignature {
args: vec![FuncArg { name: "a".to_string(), ty: primitives.int32, default_value: None }],
ret: primitives.int32,
vars: HashMap::new(),
};
let fun_ty = unifier.add_ty(TypeEnum::TFunc(RefCell::new(signature.clone())));
let foo_id = top_level.definitions.read().len();
top_level.definitions.write().push(Arc::new(RwLock::new(TopLevelDef::Function {
name: "foo".to_string(),
signature: fun_ty,
var_id: vec![],
instance_to_stmt: HashMap::new(),
instance_to_symbol: HashMap::new(),
resolver: None,
})));
let resolver = Box::new(Resolver {
id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(),
});
resolver.add_id_def("foo".to_string(), DefinitionId(foo_id));
let resolver = Arc::new(resolver as Box<dyn SymbolResolver + Send + Sync>);
if let TopLevelDef::Function { resolver: r, .. } =
&mut *top_level.definitions.read()[foo_id].write()
{
*r = Some(resolver.clone());
} else {
unreachable!()
}
let threads = ["test"];
let mut function_data = FunctionData {
resolver: resolver.clone(),
bound_variables: Vec::new(),
return_type: Some(primitives.int32),
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".to_string(), "foo".into()].iter().cloned().collect();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
unifier: &mut unifier,
variable_mapping: Default::default(),
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: identifiers.clone(),
};
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("foo".into(), fun_ty);
let statements_1 = statements_1
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let calls1 = inferencer.calls.clone();
inferencer.calls.clear();
let statements_2 = statements_2
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
if let TopLevelDef::Function { instance_to_stmt, .. } =
&mut *top_level.definitions.read()[foo_id].write()
{
instance_to_stmt.insert(
"".to_string(),
FunInstance {
body: statements_2,
calls: inferencer.calls.clone(),
subst: Default::default(),
unifier_id: 0,
},
);
} else {
unreachable!()
}
inferencer.check_block(&statements_1, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])),
});
let unifier = (unifier.get_shared_unifier(), primitives);
let task = CodeGenTask {
subst: Default::default(),
symbol_name: "testing".to_string(),
body: statements_1,
resolver,
unifier,
calls: calls1,
signature,
};
let f = Arc::new(WithCall::new(Box::new(|module| {
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
define i32 @testing(i32 %0) {
init:
%a = alloca i32, align 4
store i32 %0, i32* %a, align 4
br label %body
body: ; preds = %init
%load = load i32, i32* %a, align 4
%call = call i32 @foo_0(i32 %load)
store i32 %call, i32* %a, align 4
%load1 = load i32, i32* %a, align 4
%mul = mul i32 %load1, 2
ret i32 %mul
}
declare i32 @foo_0(i32)
define i32 @foo_0.1(i32 %0) {
init:
%a = alloca i32, align 4
store i32 %0, i32* %a, align 4
br label %body
body: ; preds = %init
%load = load i32, i32* %a, align 4
%add = add i32 %load, 1
ret i32 %add
}
"}
.trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
})));
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}